diff --git a/clang/include/clang/Sema/SemaSYCL.h b/clang/include/clang/Sema/SemaSYCL.h index a4f6ced16fa03..5cbc96a33d173 100644 --- a/clang/include/clang/Sema/SemaSYCL.h +++ b/clang/include/clang/Sema/SemaSYCL.h @@ -65,7 +65,8 @@ class SYCLIntegrationHeader { kind_work_group_memory, kind_dynamic_work_group_memory, kind_dynamic_accessor, - kind_last = kind_dynamic_accessor + kind_struct_with_special_type, // structs that contain special types + kind_last = kind_struct_with_special_type }; public: @@ -118,6 +119,9 @@ class SYCLIntegrationHeader { /// integration header is required. void addHostPipeRegistration() { NeedToEmitHostPipeRegistration = true; } + /// Set the ParentStruct field + void setParentStruct(ParmVarDecl *parent); + private: // Kernel actual parameter descriptor. struct KernelParamDesc { @@ -205,6 +209,20 @@ class SYCLIntegrationHeader { /// Keeps track of whether declaration of __sycl_host_pipe_registration /// type and __sycl_host_pipe_registrar variable are required to emit. bool NeedToEmitHostPipeRegistration = false; + + // For free function kernels, keeps track of the parameter that is currently + // being analyzed if it is a struct that contains special types. + ParmVarDecl *ParentStruct = nullptr; + + // For every struct that contains a special type which is given by + // the ParentStruct field above, record the offset and size of its fields + // at any nesting level. Store the information in the variable below. + llvm::DenseMap>> + OffsetSizeInfo; + // Likewise for the kind of a field i.e accessor, std_layout etc... + llvm::DenseMap> + KindInfo; }; class SYCLIntegrationFooter { @@ -267,6 +285,10 @@ class SemaSYCL : public SemaBase { llvm::DenseSet FreeFunctionDeclarations; + // A map that keeps track of all structs encountered with + // special types inside. Relevant for free function kernels only. + llvm::DenseSet StructsWithSpecialTypes; + public: SemaSYCL(Sema &S); @@ -317,6 +339,13 @@ class SemaSYCL : public SemaBase { SYCLKernelFunctions.insert(FD); } + /// Add ParentStruct to StructsWithSpecialTypes. + void addStructWithSpecialType(const RecordDecl *ParentStruct) { + StructsWithSpecialTypes.insert(ParentStruct); + } + + auto &getStructsWithSpecialType() const { return StructsWithSpecialTypes; } + /// Lazily creates and returns SYCL integration header instance. SYCLIntegrationHeader &getSyclIntegrationHeader() { if (SyclIntHeader == nullptr) diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 996aca877350c..3cad865f7afd2 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -1779,6 +1779,11 @@ class SyclKernelFieldHandler : public SyclKernelFieldHandlerBase { SemaSYCL &SemaSYCLRef; SyclKernelFieldHandler(SemaSYCL &S) : SemaSYCLRef(S) {} + // Holds the last handled kernel struct parameter that contains a special + // type. Set in the enterStruct functions. Only relevant for free function + // kernels + ParmVarDecl *ParentStruct = nullptr; + // Returns 'true' if the thing we're visiting (Based on the FD/QualType pair) // is an element of an array. FD will always be the array field. When // traversing the array field, Ty will be the type of the array field or the @@ -2189,31 +2194,12 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { } bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { - // TODO manipulate struct depth once special types are supported for free - // function kernels. - // ++StructFieldDepth; return true; } bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *PD, QualType ParamTy) final { - // TODO manipulate struct depth once special types are supported for free - // function kernels. - // --StructFieldDepth; - // TODO We don't yet support special types and therefore structs that - // require decomposition and leaving/entering. Diagnose for better user - // experience. - CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl(); - if (RD->hasAttr()) { - Diag.Report(PD->getLocation(), - diag::err_bad_kernel_param_type) - << ParamTy; - Diag.Report(PD->getLocation(), - diag::note_free_function_kernel_param_type_not_supported) - << ParamTy; - IsInvalid = true; - } - return isValid(); + return true; } bool enterStruct(const CXXRecordDecl *, const CXXBaseSpecifier &, @@ -2327,8 +2313,6 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler { } bool handleSyclSpecialType(ParmVarDecl *, QualType) final { - // TODO We don't support special types in free function kernel parameters, - // but track them to diagnose the case properly. CollectionStack.back() = true; return true; } @@ -2598,9 +2582,8 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { return true; } - bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { - // TODO - unsupportedFreeFunctionParamType(); + bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, + QualType ParamTy) final { return true; } @@ -2618,9 +2601,8 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { return true; } - bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { - // TODO - unsupportedFreeFunctionParamType(); + bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *PD, + QualType ParamTy) final { return true; } @@ -2692,9 +2674,7 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { return true; } - bool handleScalarType(ParmVarDecl *, QualType) final { - // TODO - unsupportedFreeFunctionParamType(); + bool handleScalarType(ParmVarDecl *PD, QualType ParamTy) final { return true; } @@ -2714,10 +2694,8 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { return true; } - bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *, - QualType) final { - // TODO - unsupportedFreeFunctionParamType(); + bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *PD, + QualType ParamTy) final { return true; } @@ -3019,9 +2997,11 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { return true; } - bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { - // TODO - // ++StructDepth; + bool enterStruct(const CXXRecordDecl *, ParmVarDecl *PD, QualType Ty) final { + ++StructDepth; + StringRef Name = "_arg_struct"; + addParam(Name, Ty); + ParentStruct = Params.back(); return true; } @@ -3031,8 +3011,7 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { } bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { - // TODO - // --StructDepth; + --StructDepth; return true; } @@ -3222,6 +3201,7 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { return ArrayRef(std::begin(Params) + LastParamIndex, std::end(Params)); } + ParmVarDecl *getParentStruct() { return ParentStruct; } }; // This Visitor traverses the AST of the function with @@ -4400,16 +4380,18 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { SyclKernelDeclCreator &DeclCreator; llvm::SmallVector BodyStmts; + // Keep track of the structs we have encountered on our way to a special type. + // They will be needed to properly generate the __init call. Note that the + // top-level struct parameter is not kept track here because that is done by + // the DeclCreator. + llvm::SmallVector CurrentStructs; FunctionDecl *FreeFunc = nullptr; SourceLocation FreeFunctionSrcLoc; // Free function source location. llvm::SmallVector ArgExprs; - // Creates a DeclRefExpr to the ParmVar that represents the current free - // function parameter. - Expr *createParamReferenceExpr() { - ParmVarDecl *FreeFunctionParameter = - DeclCreator.getParamVarDeclsForCurrentField()[0]; - + // Creates a DeclRefExpr to the ParmVar that represents an arbitrary + // free function parameter + Expr *createParamReferenceExpr(ParmVarDecl *FreeFunctionParameter) { QualType FreeFunctionParamType = FreeFunctionParameter->getOriginalType(); Expr *DRE = SemaSYCLRef.SemaRef.BuildDeclRefExpr( FreeFunctionParameter, FreeFunctionParamType, VK_LValue, @@ -4418,6 +4400,14 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { return DRE; } + // Creates a DeclRefExpr to the ParmVar that represents the current free + // function parameter. + Expr *createParamReferenceExpr() { + ParmVarDecl *FreeFunctionParameter = + DeclCreator.getParamVarDeclsForCurrentField()[0]; + return createParamReferenceExpr(FreeFunctionParameter); + } + // Creates a DeclRefExpr to the ParmVar that represents the current pointer // parameter. Expr *createPointerParamReferenceExpr(QualType PointerTy) { @@ -4564,9 +4554,21 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { DeclCreator.setBody(KernelBody); } - bool handleSyclSpecialType(FieldDecl *, QualType) final { - // TODO - unsupportedFreeFunctionParamType(); + bool handleSyclSpecialType(FieldDecl *FD, QualType FieldTy) final { + // FD represents a special type which is a field of a struct parameter + // passed to a free function kernel Get this struct parameter using + // getParentStruct and build the __init call. Also add the struct to the + // list of special structs needed later by the integration header to + // generate some helper structs for the runtime. + Expr *Base = createParamReferenceExpr(DeclCreator.getParentStruct()); + for (const auto &child : CurrentStructs) { + Base = buildMemberExpr(Base, child); + } + MemberExpr *MemberAccess = buildMemberExpr(Base, FD); + createSpecialMethodCall(FieldTy->getAsCXXRecordDecl(), InitMethodName, + MemberAccess, BodyStmts); + SemaSYCLRef.addStructWithSpecialType( + DeclCreator.getParentStruct()->getType()->getAsCXXRecordDecl()); return true; } @@ -4575,8 +4577,8 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { // typically if this is the case the default constructor will be private and // in such cases we must manually override the access specifier from private // to public just for the duration of this default initialization. - // TODO: Revisit this approach once https://github.com/intel/llvm/issues/16061 - // is closed. + // TODO: Revisit this approach once + // https://github.com/intel/llvm/issues/16061 is closed. bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final { // The code produced looks like this in the case of a work group memory // parameter: @@ -4669,11 +4671,7 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { return true; } - bool handleScalarType(FieldDecl *, QualType) final { - // TODO - unsupportedFreeFunctionParamType(); - return true; - } + bool handleScalarType(FieldDecl *FD, QualType FieldTy) final { return true; } bool handleScalarType(ParmVarDecl *, QualType) final { Expr *ParamRef = createParamReferenceExpr(); @@ -4693,27 +4691,25 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { return true; } - bool enterStruct(const CXXRecordDecl *, FieldDecl *, QualType) final { - // TODO - unsupportedFreeFunctionParamType(); + bool enterStruct(const CXXRecordDecl *RD, FieldDecl *FD, QualType Ty) final { + CurrentStructs.push_back(FD); return true; } - bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { - // TODO - unsupportedFreeFunctionParamType(); + bool enterStruct(const CXXRecordDecl *RD, ParmVarDecl *PD, + QualType ParamTy) final { return true; } - bool leaveStruct(const CXXRecordDecl *, FieldDecl *, QualType) final { - // TODO - unsupportedFreeFunctionParamType(); + bool leaveStruct(const CXXRecordDecl *, FieldDecl *FD, QualType Ty) final { + CurrentStructs.pop_back(); return true; } bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { - // TODO - unsupportedFreeFunctionParamType(); + ArgExprs.push_back(SemaSYCLRef.SemaRef.BuildDeclRefExpr( + DeclCreator.getParentStruct(), DeclCreator.getParentStruct()->getType(), + VK_PRValue, FreeFunctionSrcLoc)); return true; } @@ -4754,6 +4750,11 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { unsupportedFreeFunctionParamType(); return true; } + FieldDecl *getCurrentStruct() { + assert(CurrentStructs.size() && + "Current free function parameter is not inside a structure!"); + return CurrentStructs.back(); + } }; // Kernels are only the unnamed-lambda feature if the feature is enabled, AND @@ -4796,13 +4797,9 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { addParam(ArgTy, Kind, offsetOf(FD, ArgTy)); } - // For free functions we increment the current offset as each parameter is - // added. void addParam(const ParmVarDecl *PD, QualType ParamTy, SYCLIntegrationHeader::kernel_param_kind_t Kind) { addParam(ParamTy, Kind, offsetOf(PD, ParamTy)); - CurOffset += - SemaSYCLRef.getASTContext().getTypeSizeInChars(ParamTy).getQuantity(); } void addParam(QualType ParamTy, @@ -4986,8 +4983,8 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { } bool handleSimpleArrayType(FieldDecl *FD, QualType FieldTy) final { - // Arrays are always wrapped inside of structs, so just treat it as a simple - // struct. + // Arrays are always wrapped inside of structs, so just treat it as a + // simple struct. addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout); return true; } @@ -5043,9 +5040,9 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { return true; } - bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { - // TODO - unsupportedFreeFunctionParamType(); + bool enterStruct(const CXXRecordDecl *, ParmVarDecl *PD, QualType Ty) final { + addParam(PD, Ty, SYCLIntegrationHeader::kind_struct_with_special_type); + Header.setParentStruct(PD); return true; } @@ -5056,8 +5053,7 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { } bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { - // TODO - unsupportedFreeFunctionParamType(); + Header.setParentStruct(nullptr); return true; } @@ -6149,6 +6145,7 @@ static const char *paramKind2Str(KernelParamKind K) { CASE(work_group_memory); CASE(dynamic_work_group_memory); CASE(dynamic_accessor); + CASE(struct_with_special_type); } return ""; @@ -7185,6 +7182,10 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { unsigned ShimCounter = 1; int FreeFunctionCount = 0; + // Structs with special types inside needs some special code generation in the + // header and we keep this visited map to not have duplicates in case several + // free function kernels use the same struct type as parameters. + llvm::DenseMap visitedStructWithSpecialType; for (const KernelDesc &K : KernelDescs) { if (!S.isFreeFunction(K.SyclKernel)) continue; @@ -7270,6 +7271,67 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { FFPrinter.printFreeFunctionDeclaration(K.SyclKernel, ParmListWithNames); } + // Now we handle all structs that contain special types + // inside. Their information is contained in StructsWithSpecialTypes of + // SemaSYCL. + for (ParmVarDecl *Param : K.SyclKernel->parameters()) { + if (!Param->getType()->isStructureType()) + continue; + const RecordDecl *Struct = Param->getType()->getAsRecordDecl(); + QualType type = Param->getType(); + if (!S.getStructsWithSpecialType().count(Struct) || + visitedStructWithSpecialType.count(Struct)) + continue; + + FwdDeclEmitter.Visit(type.getDesugaredType(S.getASTContext())); + + // this is a struct that contains a special type so its neither a + // special type nor a trivially copyable type. We therefore need to + // explicitly communicate to the runtime that this argument should be + // allowed as a free function kernel argument. We do this by defining + // is_struct_with_special_type to be true. This helper struct also + // contains information about the offset, size and parameter + // kind of every field inside the struct at any nesting level + // This facilitates setting the arguments in the runtime. + // We also define is_device_copyable trait to be true for this type to + // allow it being passed in device kernels. + O << "template <>\n"; + O << "struct " + "sycl::is_device_copyable<"; + Policy.SuppressTagKeyword = true; + type.print(O, Policy); + O << ">: std::true_type {};\n"; + + O << "template <>\n"; + O << "struct " + "sycl::ext::oneapi::experimental::detail::is_struct_with_special_" + "type<"; + Policy.SuppressTagKeyword = true; + type.print(O, Policy); + O << "> {\n"; + O << " inline static constexpr bool value = true;\n"; + O << " static constexpr int offsets[] = { "; + for (const auto OffsetSize : OffsetSizeInfo[Param]) { + O << OffsetSize.first << ", "; + } + O << "-1};\n "; + + O << " static constexpr int sizes[] = { "; + for (const auto OffsetSize : OffsetSizeInfo[Param]) { + O << OffsetSize.second << ", "; + } + O << "-1}; \n "; + + O << " static constexpr sycl::detail::kernel_param_kind_t kinds[] = {\n "; + for (const auto Kind : KindInfo[Param]) { + O << " sycl::detail::kernel_param_kind_t::" << paramKind2Str(Kind); + O << ",\n "; + } + O << "sycl::detail::kernel_param_kind_t::kind_invalid }; \n};\n\n "; + + visitedStructWithSpecialType[Struct] = true; + } + Policy.SuppressTagKeyword = false; FFPrinter.printFreeFunctionShim(K.SyclKernel, ShimCounter, ParmList); O << ";\n"; O << "}\n"; @@ -7355,6 +7417,21 @@ void SYCLIntegrationHeader::addParamDesc(kernel_param_kind_t Kind, int Info, PD.Kind = Kind; PD.Info = Info; PD.Offset = Offset; + // If we are adding a free function kernel parameter that is a struct that + // contains a special type, a little more work needs to be done in order to + // help the runtime set the kernel arguments properly. Add the offset, size, + // and Kind information to the integration header for each field inside this + // struct. Also, verify that we are actually adding a field and not the struct + // itself by checking the Kind. + if (ParentStruct && + Kind != kernel_param_kind_t::kind_struct_with_special_type) { + OffsetSizeInfo[ParentStruct].emplace_back(std::make_pair(Offset, Info)); + KindInfo[ParentStruct].emplace_back(Kind); + } +} + +void SYCLIntegrationHeader::setParentStruct(ParmVarDecl *parent) { + ParentStruct = parent; } void SYCLIntegrationHeader::endKernel() { diff --git a/clang/test/CodeGenSYCL/free_function_int_header.cpp b/clang/test/CodeGenSYCL/free_function_int_header.cpp index 4fe7a761e98c6..250ac721ab583 100644 --- a/clang/test/CodeGenSYCL/free_function_int_header.cpp +++ b/clang/test/CodeGenSYCL/free_function_int_header.cpp @@ -278,6 +278,49 @@ void ff_24(int arg); void ff_24(int arg) { } +// Tests with parameter types that are structs that contain special types inside e.g accessor + +struct AccessorAndLocalAccessor { + sycl::accessor acc; + sycl::local_accessor lacc; +}; + +struct AccessorAndInt { + sycl::accessor acc; + int a; +}; + +struct IntAndAccessor { + int a; + sycl::accessor acc; +}; + +struct SecondLevelAccessor { + AccessorAndInt accAndInt; +}; + +template +struct TemplatedAccessorStruct { + sycl::accessor acc; + sycl::local_accessor lacc; +}; + +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 0)]] +void ff_25(AccessorAndLocalAccessor arg1) { +} + +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 0)]] +void ff_26(AccessorAndLocalAccessor arg1, SecondLevelAccessor arg2) { +} + +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 0)]] +void ff_27(IntAndAccessor arg1, AccessorAndInt) { +} + +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 0)]] +void ff_28(TemplatedAccessorStruct arg1) { +} + // CHECK: const char* const kernel_names[] = { // CHECK-NEXT: {{.*}}__sycl_kernel_ff_2Piii // CHECK-NEXT: {{.*}}__sycl_kernel_ff_2Piiii @@ -313,6 +356,11 @@ void ff_24(int arg) { // CHECK-NEXT: {{.*}}__sycl_kernel_ff_217DerivedPS_ // CHECK-NEXT: {{.*}}__sycl_kernel_ff_227DerivedPS_ // CHECK-NEXT: {{.*}}__sycl_kernel_ff_24i" +// CHECK-NEXT: {{.*}}__sycl_kernel_ff_2524AccessorAndLocalAccessor", +// CHECK-NEXT: {{.*}}__sycl_kernel_ff_2624AccessorAndLocalAccessor19SecondLevelAccessor", +// CHECK-NEXT: {{.*}}__sycl_kernel_ff_2714IntAndAccessor14AccessorAndInt", +// CHECK-NEXT: {{.*}}__sycl_kernel_ff_2823TemplatedAccessorStructIiE", + // CHECK-NEXT: {{.*}}__sycl_kernel_ff_23i" // CHECK-NEXT: "" @@ -321,39 +369,39 @@ void ff_24(int arg) { // CHECK: const kernel_param_desc_t kernel_signatures[] = { // CHECK-NEXT: {{.*}}__sycl_kernel_ff_2Piii // CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 8 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 12 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, // CHECK: {{.*}}__sycl_kernel_ff_2Piiii // CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 8 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 12 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 16 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, // CHECK: {{.*}}__sycl_kernel_ff_3IiEvPT_S0_S0_ // CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 8 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 12 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, // CHECK: {{.*}}__sycl_kernel_ff_3IfEvPT_S0_S0_ // CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 8 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 12 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, // CHECK: {{.*}}__sycl_kernel_ff_3IdEvPT_S0_S0_ // CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 8, 8 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 8, 16 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 8, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 8, 0 }, // CHECK: //--- _Z18__sycl_kernel_ff_410NoPointers8Pointers3Agg // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 16, 4 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 32, 20 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 16, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 32, 0 }, // CHECK: //--- _Z18__sycl_kernel_ff_6I3Agg7DerivedEvT_T0_i // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 32, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 40, 32 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 72 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 40, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, // CHECK: //--- _Z18__sycl_kernel_ff_7ILi3EEv16KArgWithPtrArrayIXT_EE // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 48, 0 }, @@ -364,27 +412,27 @@ void ff_24(int arg) { // CHECK: //--- _ZN28__sycl_kernel_free_functions4ff_9EiPi // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 4 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, // CHECK: //--- _ZN28__sycl_kernel_free_functions5tests5ff_10EiPi // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 4 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, // CHECK: //--- _ZN28__sycl_kernel_free_functions5tests2V15ff_11EiPi // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 4 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, // CHECK: //--- _ZN26__sycl_kernel__GLOBAL__N_15ff_12EiPi // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 4 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, // CHECK: //--- _ZN28__sycl_kernel_free_functions5ff_13EiPi // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 4 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, // CHECK: //--- _ZN28__sycl_kernel_free_functions5tests5ff_13EiPi // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 4 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, // CHECK: //--- _Z18__sycl_kernel_ff_9N4sycl3_V125dynamic_work_group_memoryIiEE // CHECK-NEXT: { kernel_param_kind_t::kind_dynamic_work_group_memory, 8, 0 }, @@ -409,23 +457,23 @@ void ff_24(int arg) { // CHECK: //--- _ZN28__sycl_kernel_free_functions5tests5ff_14EiPi // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 4 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, // CHECK: //--- _ZN28__sycl_kernel_free_functions5ff_15EiPi // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 4 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, // CHECK: //--- _ZN28__sycl_kernel_free_functions5ff_16E3AggPS0_ // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 32, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 32 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, // CHECK: //--- _ZN28__sycl_kernel_free_functions5ff_17E7DerivedPS0_ // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 40, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 40 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, // CHECK: //--- _ZN28__sycl_kernel_free_functions5tests5ff_18ENS_3AggEPS1_ // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 8, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 8 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, // CHECK: //--- _Z19__sycl_kernel_ff_19N14free_functions16KArgWithPtrArrayILi50EEE // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 800, 0 }, @@ -436,6 +484,32 @@ void ff_24(int arg) { // CHECK: //--- _Z19__sycl_kernel_ff_24i // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, +// CHECK: //--- _Z19__sycl_kernel_ff_2524AccessorAndLocalAccessor +// CHECK-NEXT: { kernel_param_kind_t::kind_struct_with_special_type, 36, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4064, 12 }, + +// CHECK: //--- _Z19__sycl_kernel_ff_2624AccessorAndLocalAccessor19SecondLevelAccessor +// CHECK-NEXT: { kernel_param_kind_t::kind_struct_with_special_type, 36, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4064, 12 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_struct_with_special_type, 16, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 12 }, + +// CHECK: //--- _Z19__sycl_kernel_ff_2714IntAndAccessor14AccessorAndInt +// CHECK-NEXT: { kernel_param_kind_t::kind_struct_with_special_type, 16, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 4 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_struct_with_special_type, 16, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 12 }, + +// CHECK: //--- _Z19__sycl_kernel_ff_2823TemplatedAccessorStructIiE +// CHECK-NEXT: { kernel_param_kind_t::kind_struct_with_special_type, 36, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4064, 12 }, + // CHECK: //--- _Z19__sycl_kernel_ff_23i // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, @@ -1531,18 +1605,147 @@ void ff_24(int arg) { // CHECK-NEXT: static constexpr bool value = true; // CHECK-NEXT: }; +// CHECK: Definition of _Z19__sycl_kernel_ff_2524AccessorAndLocalAccessor as a free function kernel +// CHECK: Forward declarations of kernel and its argument types: +// CHECK: void ff_25(AccessorAndLocalAccessor arg1); +// CHECK-NEXT: template <> +// CHECK-NEXT: struct sycl::is_device_copyable: std::true_type {}; +// CHECK-NEXT: template <> +// CHECK-NEXT: struct sycl::ext::oneapi::experimental::detail::is_struct_with_special_type { +// CHECK-NEXT: inline static constexpr bool value = true; +// CHECK-NEXT: static constexpr int offsets[] = { 0, 12, -1}; +// CHECK-NEXT: static constexpr int sizes[] = { 4062, 4064, -1}; +// CHECK-NEXT: static constexpr sycl::detail::kernel_param_kind_t kinds[] = { +// CHECK-NEXT: sycl::detail::kernel_param_kind_t::kind_accessor, +// CHECK-NEXT: sycl::detail::kernel_param_kind_t::kind_accessor, +// CHECK-NEXT: sycl::detail::kernel_param_kind_t::kind_invalid }; +// CHECK-NEXT: }; + +// CHECK: static constexpr auto __sycl_shim33() { +// CHECK-NEXT: return (void (*)(struct AccessorAndLocalAccessor))ff_25; +// CHECK-NEXT: } + +// CHECK: struct ext::oneapi::experimental::is_kernel<__sycl_shim33()> { +// CHECK-NEXT: static constexpr bool value = true; +// CHECK-NEXT: }; +// CHECK-NEXT: template <> +// CHECK-NEXT: struct ext::oneapi::experimental::is_single_task_kernel<__sycl_shim33()> { +// CHECK-NEXT: static constexpr bool value = true; +// CHECK-NEXT: }; + +// CHECK: Definition of _Z19__sycl_kernel_ff_2624AccessorAndLocalAccessor19SecondLevelAccessor as a free function kernel +// CHECK: Forward declarations of kernel and its argument types: +// CHECK: void ff_26(AccessorAndLocalAccessor arg1, SecondLevelAccessor arg2); +// CHECK-NEXT: template <> +// CHECK-NEXT: struct sycl::is_device_copyable: std::true_type {}; +// CHECK-NEXT: template <> +// CHECK-NEXT: struct sycl::ext::oneapi::experimental::detail::is_struct_with_special_type { +// CHECK-NEXT: inline static constexpr bool value = true; +// CHECK-NEXT: static constexpr int offsets[] = { 0, 12, -1}; +// CHECK-NEXT: static constexpr int sizes[] = { 4062, 4, -1}; +// CHECK-NEXT: static constexpr sycl::detail::kernel_param_kind_t kinds[] = { +// CHECK-NEXT: sycl::detail::kernel_param_kind_t::kind_accessor, +// CHECK-NEXT: sycl::detail::kernel_param_kind_t::kind_std_layout, +// CHECK-NEXT: sycl::detail::kernel_param_kind_t::kind_invalid }; +// CHECK-NEXT: }; + +// CHECK: static constexpr auto __sycl_shim34() { +// CHECK-NEXT: return (void (*)(struct AccessorAndLocalAccessor, struct SecondLevelAccessor))ff_26; +// CHECK-NEXT: } + +// CHECK: struct ext::oneapi::experimental::is_kernel<__sycl_shim34()> { +// CHECK-NEXT: static constexpr bool value = true; +// CHECK-NEXT: }; +// CHECK-NEXT: template <> +// CHECK-NEXT: struct ext::oneapi::experimental::is_single_task_kernel<__sycl_shim34()> { +// CHECK-NEXT: static constexpr bool value = true; +// CHECK-NEXT }; + +// CHECK: Definition of _Z19__sycl_kernel_ff_2714IntAndAccessor14AccessorAndInt as a free function kernel +// CHECK: Forward declarations of kernel and its argument types: +// CHECK: void ff_27(IntAndAccessor arg1, AccessorAndInt ); +// CHECK-NEXT: template <> +// CHECK-NEXT: struct sycl::is_device_copyable: std::true_type {}; +// CHECK-NEXT: template <> +// CHECK-NEXT: struct sycl::ext::oneapi::experimental::detail::is_struct_with_special_type { +// CHECK-NEXT: inline static constexpr bool value = true; +// CHECK-NEXT: static constexpr int offsets[] = { 0, 4, -1}; +// CHECK-NEXT: static constexpr int sizes[] = { 4, 4062, -1}; +// CHECK-NEXT: static constexpr sycl::detail::kernel_param_kind_t kinds[] = { +// CHECK-NEXT: sycl::detail::kernel_param_kind_t::kind_std_layout, +// CHECK-NEXT: sycl::detail::kernel_param_kind_t::kind_accessor, +// CHECK-NEXT: sycl::detail::kernel_param_kind_t::kind_invalid }; +// CHECK-NEXT: }; + +// CHECK: template <> +// CHECK-NEXT: struct sycl::is_device_copyable: std::true_type {}; +// CHECK-NEXT: template <> +// CHECK-NEXT: struct sycl::ext::oneapi::experimental::detail::is_struct_with_special_type { +// CHECK-NEXT: inline static constexpr bool value = true; +// CHECK-NEXT: static constexpr int offsets[] = { 0, 12, -1}; +// CHECK-NEXT: static constexpr int sizes[] = { 4062, 4, -1}; +// CHECK-NEXT: static constexpr sycl::detail::kernel_param_kind_t kinds[] = { +// CHECK-NEXT: sycl::detail::kernel_param_kind_t::kind_accessor, +// CHECK-NEXT: sycl::detail::kernel_param_kind_t::kind_std_layout, +// CHECK-NEXT: sycl::detail::kernel_param_kind_t::kind_invalid }; +// CHECK-NEXT: }; + + + +// CHECK: static constexpr auto __sycl_shim35() { +// CHECK-NEXT: return (void (*)(struct IntAndAccessor, struct AccessorAndInt))ff_27; +// CHECK-NEXT: } + +// CHECK: struct ext::oneapi::experimental::is_kernel<__sycl_shim35()> { +// CHECK-NEXT: static constexpr bool value = true; +// CHECK-NEXT: }; +// CHECK-NEXT: template <> +// CHECK-NEXT: struct ext::oneapi::experimental::is_single_task_kernel<__sycl_shim35()> { +// CHECK-NEXT: static constexpr bool value = true; +// CHECK-NEXT: }; + + +// CHECK: Definition of _Z19__sycl_kernel_ff_2823TemplatedAccessorStructIiE as a free function kernel +// CHECK: Forward declarations of kernel and its argument types: +// CHECK: template struct TemplatedAccessorStruct; +// CHECK: void ff_28(TemplatedAccessorStruct arg1); +// CHECK-NEXT: template <> +// CHECK-NEXT: struct sycl::is_device_copyable>: std::true_type {}; +// CHECK-NEXT: template <> +// CHECK-NEXT: struct sycl::ext::oneapi::experimental::detail::is_struct_with_special_type> { +// CHECK-NEXT: inline static constexpr bool value = true; +// CHECK-NEXT: static constexpr int offsets[] = { 0, 12, -1}; +// CHECK-NEXT: static constexpr int sizes[] = { 4062, 4064, -1}; +// CHECK-NEXT: static constexpr sycl::detail::kernel_param_kind_t kinds[] = { +// CHECK-NEXT: sycl::detail::kernel_param_kind_t::kind_accessor, +// CHECK-NEXT: sycl::detail::kernel_param_kind_t::kind_accessor, +// CHECK-NEXT sycl::detail::kernel_param_kind_t::kind_invalid }; +// CHECK-NEXT: }; + +// CHECK: static constexpr auto __sycl_shim36() { +// CHECK-NEXT: return (void (*)(struct TemplatedAccessorStruct))ff_28; +// CHECK-NEXT: } + +// CHECK: struct ext::oneapi::experimental::is_kernel<__sycl_shim36()> { +// CHECK-NEXT: static constexpr bool value = true; +// CHECK-NEXT: }; +// CHECK-NEXT: template <> +// CHECK-NEXT: struct ext::oneapi::experimental::is_single_task_kernel<__sycl_shim36()> { +// CHECK-NEXT: static constexpr bool value = true; +// CHECK-NEXT: }; + // CHECK: Definition of _Z19__sycl_kernel_ff_23i as a free function kernel // CHECK: Forward declarations of kernel and its argument types: // CHECK: void ff_23(int arg); -// CHECK-NEXT: static constexpr auto __sycl_shim33() { +// CHECK-NEXT: static constexpr auto __sycl_shim37() { // CHECK-NEXT: return (void (*)(int))ff_23; // CHECK-NEXT: } // CHECK: namespace sycl { // CHECK-NEXT: inline namespace _V1 { // CHECK-NEXT: namespace detail { -// CHECK-NEXT: //Free Function Kernel info specialization for shim33 -// CHECK-NEXT: template <> struct FreeFunctionInfoData<__sycl_shim33()> { +// CHECK-NEXT: //Free Function Kernel info specialization for shim37 +// CHECK-NEXT: template <> struct FreeFunctionInfoData<__sycl_shim37()> { // CHECK-NEXT: __SYCL_DLL_LOCAL // CHECK-NEXT: static constexpr unsigned getNumParams() { return 1; } // CHECK-NEXT: __SYCL_DLL_LOCAL @@ -1554,11 +1757,11 @@ void ff_24(int arg) { // CHECK: namespace sycl { // CHECK-NEXT: template <> -// CHECK-NEXT: struct ext::oneapi::experimental::is_kernel<__sycl_shim33()> { +// CHECK-NEXT: struct ext::oneapi::experimental::is_kernel<__sycl_shim37()> { // CHECK-NEXT: static constexpr bool value = true; // CHECK-NEXT: }; // CHECK-NEXT: template <> -// CHECK-NEXT: struct ext::oneapi::experimental::is_single_task_kernel<__sycl_shim33()> { +// CHECK-NEXT: struct ext::oneapi::experimental::is_single_task_kernel<__sycl_shim37()> { // CHECK-NEXT: static constexpr bool value = true; // CHECK-NEXT: }; @@ -1572,7 +1775,7 @@ void ff_24(int arg) { // CHECK-NEXT: namespace detail { // CHECK-NEXT: struct GlobalMapUpdater { // CHECK-NEXT: GlobalMapUpdater() { -// CHECK-NEXT: sycl::detail::free_function_info_map::add(sycl::detail::kernel_names, sycl::detail::kernel_args_sizes, 33); +// CHECK-NEXT: sycl::detail::free_function_info_map::add(sycl::detail::kernel_names, sycl::detail::kernel_args_sizes, 37); // CHECK-NEXT: } // CHECK-NEXT: }; // CHECK-NEXT: static GlobalMapUpdater updater; diff --git a/clang/test/SemaSYCL/free_function_kernel_params_restrictions.cpp b/clang/test/SemaSYCL/free_function_kernel_params_restrictions.cpp index d1bdc0e3da475..c7b2d2de8921c 100644 --- a/clang/test/SemaSYCL/free_function_kernel_params_restrictions.cpp +++ b/clang/test/SemaSYCL/free_function_kernel_params_restrictions.cpp @@ -42,20 +42,3 @@ __attribute__((sycl_device)) [[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 0)]] void ff_5(A S1) { } - - - -struct StructWithAccessor { - sycl::accessor acc; - int *ptr; -}; - -struct Wrapper { - StructWithAccessor SWA; - -}; - -[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 0)]] -void ff_6(Wrapper S1) { // expected-error {{cannot be used as the type of a kernel parameter}} - // expected-note@-1 {{'Wrapper' is not yet supported as a free function kernel parameter}} -} diff --git a/sycl/include/sycl/detail/kernel_desc.hpp b/sycl/include/sycl/detail/kernel_desc.hpp index 2e6f5fdad5f80..e3134accc29f2 100644 --- a/sycl/include/sycl/detail/kernel_desc.hpp +++ b/sycl/include/sycl/detail/kernel_desc.hpp @@ -61,7 +61,8 @@ enum class kernel_param_kind_t { kind_work_group_memory = 6, kind_dynamic_work_group_memory = 7, kind_dynamic_accessor = 8, - kind_invalid = 0xf, // not a valid kernel kind + kind_struct_with_special_type = 9, // structs that contain special types + kind_invalid = 0xf, // not a valid kernel kind }; // describes a kernel parameter diff --git a/sycl/include/sycl/ext/oneapi/experimental/free_function_traits.hpp b/sycl/include/sycl/ext/oneapi/experimental/free_function_traits.hpp index 2b5d1f4190d21..f399c380fd5f8 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/free_function_traits.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/free_function_traits.hpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #pragma once +#include +#include namespace sycl { inline namespace _V1 { @@ -44,6 +46,26 @@ template struct is_kernel { template inline constexpr bool is_kernel_v = is_kernel::value; +namespace detail { +// A struct with special type is a struct type that contains special types +// passed as a paremeter to a free function kernel. It is decomposed into its +// consituents by the frontend which puts the relevant informaton about each of +// them into the struct below, namely offset, size and parameter kind for each +// one of them. The runtime then calls the addArg function to add each one of +// them as kernel arguments. The value bool is used to distinguish these structs +// from ordinary e.g standard layout structs. +template struct is_struct_with_special_type { + static constexpr bool value = false; + static constexpr int offsets[] = {-1}; + static constexpr int sizes[] = {-1}; + static constexpr sycl::detail::kernel_param_kind_t kinds[] = { + sycl::detail::kernel_param_kind_t::kind_invalid}; +}; + +} // namespace detail } // namespace ext::oneapi::experimental + +template struct is_device_copyable; + } // namespace _V1 } // namespace sycl diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index 67f21bc05857f..4e716b25943e3 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -691,6 +692,10 @@ class __SYCL_EXPORT handler { if (!std::is_same::value && std::is_pointer::value) { addArg(detail::kernel_param_kind_t::kind_pointer, StoredArg, sizeof(T), ArgIndex); + } else if (ext::oneapi::experimental::detail::is_struct_with_special_type< + remove_cv_ref_t>::value) { + addArg(detail::kernel_param_kind_t::kind_struct_with_special_type, + StoredArg, sizeof(T), ArgIndex); } else { addArg(detail::kernel_param_kind_t::kind_std_layout, StoredArg, sizeof(T), ArgIndex); @@ -1632,7 +1637,10 @@ class __SYCL_EXPORT handler { || (!is_same_type::value && std::is_pointer_v>) // USM || is_same_type::value // Interop - || is_same_type::value; // Stream + || is_same_type::value // Stream + || + sycl::is_device_copyable_v>; // Structs that contain + // special types }; /// Sets argument for OpenCL interoperability kernels. @@ -1645,6 +1653,51 @@ class __SYCL_EXPORT handler { typename std::enable_if_t::value, void> set_arg(int ArgIndex, T &&Arg) { setArgHelper(ArgIndex, std::move(Arg)); + ++ArgIndex; + // The following concerns free function kernels only. + // if we are dealing with a struct parameter that contains special types + // inside, we call addArg for each field of the struct(special and standard + // layout included) at any nesting level using the information provided by + // the frontend with the arrays offsets, sizes, and kinds which as the name + // suggests, provide the offset, size and kind of each such field. + if constexpr (ext::oneapi::experimental::detail:: + is_struct_with_special_type>::value) { + using type = + ext::oneapi::experimental::detail::is_struct_with_special_type< + remove_cv_ref_t>; + int NumArgs = 0; + while (type::offsets[NumArgs] != -1) { + void *FieldArg = (char *)(&Arg) + type::offsets[NumArgs]; + // treat accessors separately since we have to fetch the data ptr and + // pass that to the addArg function rather than the address of the + // accessor object itself. + if (type::kinds[NumArgs] == + detail::kernel_param_kind_t::kind_accessor) { + constexpr int AccessTargetMask = 0x7ff; + const access::target target = static_cast( + type::sizes[NumArgs] & AccessTargetMask); + if (target == target::local) { + detail::LocalAccessorBaseHost *LocalAccBase = + (detail::LocalAccessorBaseHost *)(FieldArg); + setLocalAccessorArgHelper(ArgIndex + NumArgs, *LocalAccBase); + } else { + detail::AccessorBaseHost *AccBase = + (detail::AccessorBaseHost *)(FieldArg); + const detail::AccessorImplPtr &AccImpl = + detail::getSyclObjImpl(*AccBase); + detail::AccessorImplHost *Req = AccImpl.get(); + addArg(type::kinds[NumArgs], Req, type::sizes[NumArgs], + ArgIndex + NumArgs); + } + } else { + // for non-accessors, simply call addArg normally. + addArg(type::kinds[NumArgs], FieldArg, type::sizes[NumArgs], + ArgIndex + NumArgs); + } + ++NumArgs; + } + incrementArgShift(NumArgs); + } } template (args)...); } - void clearArgs() { MArgs.clear(); } + void clearArgs() { + MArgs.clear(); + MArgShift = 0; + } detail::NDRDescT &getNDRDesc() & { return MNDRDesc; } @@ -181,6 +184,10 @@ class KernelData { void extractArgsAndReqsFromLambda(); + void incrementArgShift(int Shift); + + int getArgShift() const; + private: // Storage for any SYCL Graph dynamic parameters which have been flagged for // registration in the CG, along with the argument index for the parameter. @@ -204,6 +211,14 @@ class KernelData { // A pointer to device kernel information. Cached on the application side in // headers or retrieved from program manager. DeviceKernelInfo *MDeviceKernelInfoPtr = nullptr; + + // Certain arguments such as structs that contain SYCL special types entail + // several hidden set_arg calls for every set_arg called by the user. This + // shift is required to make sure the following arguments set by the user have + // the correct index. It keeps track of how many of these hidden set_arg calls + // have been made so far. The user cannot possibly know this, hence we need to + // keep track of this information. + int MArgShift = 0; }; } // namespace detail diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index 142d8f958f1ea..8596090a325c6 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -2348,6 +2348,11 @@ static void SetArgBasedOnType( break; } + case kernel_param_kind_t::kind_struct_with_special_type: { + Adapter.call(Kernel, NextTrueIndex, + Arg.MSize, nullptr, Arg.MPtr); + break; + } case kernel_param_kind_t::kind_sampler: { sampler *SamplerPtr = (sampler *)Arg.MPtr; ur_sampler_handle_t Sampler = diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index c881d4da9efb9..ed715724aeac8 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -2245,7 +2245,8 @@ void handler::addLifetimeSharedPtrStorage(std::shared_ptr SPtr) { void handler::addArg(detail::kernel_param_kind_t ArgKind, void *Req, int AccessTarget, int ArgIndex) { - impl->MKernelData.addArg(ArgKind, Req, AccessTarget, ArgIndex); + impl->MKernelData.addArg(ArgKind, Req, AccessTarget, + ArgIndex + impl->MKernelData.getArgShift()); } #ifndef __INTEL_PREVIEW_BREAKING_CHANGES @@ -2403,6 +2404,10 @@ void handler::setDeviceKernelInfoPtr( impl->MKernelData.setDeviceKernelInfoPtr(DeviceKernelInfoPtr); } +void handler::incrementArgShift(int Shift) { + impl->MKernelData.incrementArgShift(Shift); +} + void handler::setKernelFunc(void *KernelFuncPtr) { impl->MKernelData.setKernelFunc(KernelFuncPtr); } diff --git a/sycl/test-e2e/FreeFunctionKernels/structs_with_special_types_as_kernel_paramters.cpp b/sycl/test-e2e/FreeFunctionKernels/structs_with_special_types_as_kernel_paramters.cpp index 72f4ca099fee3..81019f1e548c6 100644 --- a/sycl/test-e2e/FreeFunctionKernels/structs_with_special_types_as_kernel_paramters.cpp +++ b/sycl/test-e2e/FreeFunctionKernels/structs_with_special_types_as_kernel_paramters.cpp @@ -4,9 +4,6 @@ // This test verifies whether struct that contains either sycl::local_accesor or // sycl::accessor can be used with free function kernels extension. -// XFAIL: * -// XFAIL-TRACKER: CMPLRLLVM-67737 - #include #include #include diff --git a/sycl/test/abi/sycl_symbols_linux.dump b/sycl/test/abi/sycl_symbols_linux.dump index ec2a21ef34424..b570960838a36 100644 --- a/sycl/test/abi/sycl_symbols_linux.dump +++ b/sycl/test/abi/sycl_symbols_linux.dump @@ -3647,6 +3647,7 @@ _ZN4sycl3_V17handler8getQueueEv _ZN4sycl3_V17handler8prefetchEPKvm _ZN4sycl3_V17handler8prefetchEPKvmNS0_3ext6oneapi12experimental13prefetch_typeE _ZN4sycl3_V17handler9clearArgsEv +_ZN4sycl3_V17handler17incrementArgShiftEi _ZN4sycl3_V17handler9fill_implEPvPKvmm _ZN4sycl3_V17handlerC1EOSt10unique_ptrINS0_6detail12handler_implESt14default_deleteIS4_EE _ZN4sycl3_V17handlerC1ESt10shared_ptrINS0_3ext6oneapi12experimental6detail10graph_implEE diff --git a/sycl/test/abi/sycl_symbols_windows.dump b/sycl/test/abi/sycl_symbols_windows.dump index 8eed8b8dba437..61b2e18ecf814 100644 --- a/sycl/test/abi/sycl_symbols_windows.dump +++ b/sycl/test/abi/sycl_symbols_windows.dump @@ -3837,6 +3837,7 @@ ?category@exception@_V1@sycl@@QEBAAEBVerror_category@std@@XZ ?checkNodePropertiesAndThrow@modifiable_command_graph@detail@experimental@oneapi@ext@_V1@sycl@@KAXAEBVproperty_list@67@@Z ?clearArgs@handler@_V1@sycl@@AEAAXXZ +?incrementArgShift@handler@_V1@sycl@@AEAAXH@Z ?code@exception@_V1@sycl@@QEBAAEBVerror_code@std@@XZ ?compile_from_source@detail@experimental@oneapi@ext@_V1@sycl@@YA?AV?$kernel_bundle@$00@56@AEAV?$kernel_bundle@$02@56@AEBV?$vector@Vdevice@_V1@sycl@@V?$allocator@Vdevice@_V1@sycl@@@std@@@std@@AEBV?$vector@Vstring_view@detail@_V1@sycl@@V?$allocator@Vstring_view@detail@_V1@sycl@@@std@@@std@@PEAVstring@156@2@Z ?compile_impl@detail@_V1@sycl@@YA?AV?$shared_ptr@Vkernel_bundle_impl@detail@_V1@sycl@@@std@@AEBV?$kernel_bundle@$0A@@23@AEBV?$vector@Vdevice@_V1@sycl@@V?$allocator@Vdevice@_V1@sycl@@@std@@@5@AEBVproperty_list@23@@Z diff --git a/sycl/test/include_deps/sycl_detail_core.hpp.cpp b/sycl/test/include_deps/sycl_detail_core.hpp.cpp index f4c33d1ed938f..9b356b40c0c53 100644 --- a/sycl/test/include_deps/sycl_detail_core.hpp.cpp +++ b/sycl/test/include_deps/sycl_detail_core.hpp.cpp @@ -150,6 +150,7 @@ // CHECK-NEXT: ext/oneapi/interop_common.hpp // CHECK-NEXT: ext/oneapi/bindless_images_mem_handle.hpp // CHECK-NEXT: ext/oneapi/experimental/cluster_group_prop.hpp +// CHECK-NEXT: ext/oneapi/experimental/free_function_traits.hpp // CHECK-NEXT: ext/oneapi/experimental/raw_kernel_arg.hpp // CHECK-NEXT: ext/oneapi/experimental/use_root_sync_prop.hpp // CHECK-NEXT: kernel.hpp