Skip to content

Commit 2beda8e

Browse files
committed
Stash changes
1 parent 236139f commit 2beda8e

File tree

4 files changed

+184
-56
lines changed

4 files changed

+184
-56
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 159 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//===----------------------------------------------------------------------===//
88
// This implements Semantic Analysis for SYCL constructs.
99
//===----------------------------------------------------------------------===//
10-
10+
#include <iostream>
1111
#include "clang/Sema/SemaSYCL.h"
1212
#include "TreeTransform.h"
1313
#include "clang/AST/AST.h"
@@ -1521,8 +1521,9 @@ class KernelObjVisitor {
15211521
template <typename... HandlerTys>
15221522
void visitParam(ParmVarDecl *Param, QualType ParamTy,
15231523
HandlerTys &...Handlers) {
1524-
if (isSyclSpecialType(ParamTy, SemaSYCLRef))
1525-
KP_FOR_EACH(handleOtherType, Param, ParamTy);
1524+
if (isSyclSpecialType(ParamTy, SemaSYCLRef)){
1525+
KP_FOR_EACH(handleSyclSpecialType, Param, ParamTy);
1526+
}
15261527
else if (ParamTy->isStructureOrClassType()) {
15271528
if (KP_FOR_EACH(handleStructType, Param, ParamTy)) {
15281529
CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl();
@@ -2070,8 +2071,28 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
20702071
}
20712072

20722073
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
2073-
Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type) << ParamTy;
2074-
IsInvalid = true;
2074+
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) {
2075+
CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl();
2076+
// For free functions all struct/class kernel arguments are forward declared
2077+
// in integration header, that adds additional restrictions for kernel
2078+
// arguments.
2079+
NotForwardDeclarableReason NFDR =
2080+
isForwardDeclarable(RD, SemaSYCLRef, /*DiagForFreeFunction=*/true);
2081+
if (NFDR != NotForwardDeclarableReason::None) {
2082+
Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type)
2083+
<< ParamTy;
2084+
Diag.Report(PD->getLocation(),
2085+
diag::note_free_function_kernel_param_type_not_fwd_declarable)
2086+
<< ParamTy;
2087+
IsInvalid = true;
2088+
} else
2089+
IsInvalid = false;
2090+
}
2091+
else {
2092+
Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type)
2093+
<< ParamTy;
2094+
IsInvalid = true;
2095+
}
20752096
return isValid();
20762097
}
20772098

@@ -2224,6 +2245,7 @@ class SyclKernelUnionChecker : public SyclKernelFieldHandler {
22242245

22252246
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
22262247
// TODO
2248+
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) return true;
22272249
unsupportedFreeFunctionParamType();
22282250
return true;
22292251
}
@@ -2262,7 +2284,7 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler {
22622284
return true;
22632285
}
22642286

2265-
bool handleSyclSpecialType(ParmVarDecl *, QualType) final {
2287+
bool handleSyclSpecialType(ParmVarDecl *, QualType Ty) final {
22662288
// TODO We don't support special types in free function kernel parameters,
22672289
// but track them to diagnose the case properly.
22682290
CollectionStack.back() = true;
@@ -3008,10 +3030,29 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
30083030
return handleSpecialType(FD, FieldTy);
30093031
}
30103032

3011-
bool handleSyclSpecialType(ParmVarDecl *, QualType) final {
3033+
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
30123034
// TODO
3013-
unsupportedFreeFunctionParamType();
3014-
return true;
3035+
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) {
3036+
const auto *RecordDecl = ParamTy->getAsCXXRecordDecl();
3037+
assert(RecordDecl && "The type must be a RecordDecl");
3038+
CXXMethodDecl *InitMethod = getMethodByName(RecordDecl, InitMethodName);
3039+
assert(InitMethod && "The type must have the __init method");
3040+
3041+
// Don't do -1 here because we count on this to be the first parameter
3042+
// added (if any).
3043+
size_t ParamIndex = Params.size();
3044+
for (const ParmVarDecl *Param : InitMethod->parameters()) {
3045+
QualType ParamTy = Param->getType();
3046+
addParam(Param, ParamTy.getCanonicalType());
3047+
// Propagate add_ir_attributes_kernel_parameter attribute.
3048+
if (const auto *AddIRAttr =
3049+
Param->getAttr<SYCLAddIRAttributesKernelParameterAttr>())
3050+
Params.back()->addAttr(AddIRAttr->clone(SemaSYCLRef.getASTContext()));
3051+
}
3052+
LastParamIndex = ParamIndex;
3053+
} else
3054+
unsupportedFreeFunctionParamType();
3055+
return true;
30153056
}
30163057

30173058
RecordDecl *wrapField(FieldDecl *Field, QualType FieldTy) {
@@ -3286,9 +3327,7 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler {
32863327
}
32873328

32883329
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
3289-
// TODO
3290-
unsupportedFreeFunctionParamType();
3291-
return true;
3330+
return handleSpecialType(ParamTy);
32923331
}
32933332

32943333
bool handleSyclSpecialType(const CXXRecordDecl *, const CXXBaseSpecifier &BS,
@@ -3601,6 +3640,8 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
36013640

36023641
BodyStmts.insert(BodyStmts.end(), FinalizeStmts.begin(),
36033642
FinalizeStmts.end());
3643+
//CompoundStmt::Create(SemaSYCLRef.getASTContext(), BodyStmts,
3644+
// FPOptionsOverride(), {}, {})->dumpPretty(SemaSYCLRef.getASTContext());
36043645

36053646
return CompoundStmt::Create(SemaSYCLRef.getASTContext(), BodyStmts,
36063647
FPOptionsOverride(), {}, {});
@@ -4118,6 +4159,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
41184159
annotateHierarchicalParallelismAPICalls();
41194160
CompoundStmt *KernelBody = createKernelBody();
41204161
DeclCreator.setBody(KernelBody);
4162+
DeclCreator.getKernelDecl()->dump();
41214163
}
41224164

41234165
bool handleSyclSpecialType(FieldDecl *FD, QualType Ty) final {
@@ -4308,15 +4350,11 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
43084350
SourceLocation FreeFunctionSrcLoc; // Free function source location.
43094351
llvm::SmallVector<Expr *, 8> ArgExprs;
43104352

4311-
// Creates a DeclRefExpr to the ParmVar that represents the current free
4312-
// function parameter.
4313-
Expr *createParamReferenceExpr() {
4314-
ParmVarDecl *FreeFunctionParameter =
4315-
DeclCreator.getParamVarDeclsForCurrentField()[0];
4316-
4317-
QualType FreeFunctionParamType = FreeFunctionParameter->getOriginalType();
4353+
// Creates a DeclRefExpr to the ParmVar given be PD
4354+
Expr *createParamReferenceExpr(ParmVarDecl *PD = nullptr) {
4355+
ParmVarDecl *FreeFunctionParameter = PD ? PD : DeclCreator.getParamVarDeclsForCurrentField()[0];
43184356
Expr *DRE = SemaSYCLRef.SemaRef.BuildDeclRefExpr(
4319-
FreeFunctionParameter, FreeFunctionParamType, VK_LValue,
4357+
FreeFunctionParameter, FreeFunctionParameter->getType(), VK_LValue,
43204358
FreeFunctionSrcLoc);
43214359
DRE = SemaSYCLRef.SemaRef.DefaultLvalueConversion(DRE).get();
43224360
return DRE;
@@ -4412,8 +4450,52 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
44124450
auto CallExpr = CallExpr::Create(Context, Fn, ArgExprs, ResultTy, VK,
44134451
FreeFunctionSrcLoc, FPOptionsOverride());
44144452
BodyStmts.push_back(CallExpr);
4415-
return CompoundStmt::Create(Context, BodyStmts, FPOptionsOverride(), {},
4416-
{});
4453+
// CompoundStmt::Create(Context, BodyStmts, FPOptionsOverride(), {},
4454+
// {})->dumpPretty(SemaSYCLRef.getASTContext());
4455+
return CompoundStmt::Create(Context, BodyStmts, FPOptionsOverride(), {},
4456+
{});
4457+
}
4458+
4459+
MemberExpr *buildMemberExpr(Expr *Base, ValueDecl *Member) {
4460+
DeclAccessPair MemberDAP = DeclAccessPair::make(Member, AS_none);
4461+
MemberExpr *Result = SemaSYCLRef.SemaRef.BuildMemberExpr(
4462+
Base, /*IsArrow */ false, FreeFunctionSrcLoc, NestedNameSpecifierLoc(),
4463+
FreeFunctionSrcLoc, Member, MemberDAP,
4464+
/*HadMultipleCandidates*/ false,
4465+
DeclarationNameInfo(Member->getDeclName(), FreeFunctionSrcLoc),
4466+
Member->getType(), VK_LValue, OK_Ordinary);
4467+
return Result;
4468+
}
4469+
4470+
void createSpecialMethodCall(const CXXRecordDecl *RD, StringRef MethodName,
4471+
Expr *MemberBaseExpr,
4472+
SmallVectorImpl<Stmt *> &AddTo) {
4473+
CXXMethodDecl *Method = getMethodByName(RD, MethodName);
4474+
if (!Method)
4475+
return;
4476+
4477+
unsigned NumParams = Method->getNumParams();
4478+
llvm::SmallVector<Expr *, 4> ParamDREs(NumParams);
4479+
llvm::ArrayRef<ParmVarDecl *> KernelParameters =
4480+
DeclCreator.getParamVarDeclsForCurrentField();
4481+
for (size_t I = 0; I < NumParams; ++I) {
4482+
QualType ParamType = KernelParameters[I]->getOriginalType();
4483+
ParamDREs[I] = SemaSYCLRef.SemaRef.BuildDeclRefExpr(
4484+
KernelParameters[I], ParamType, VK_LValue, FreeFunctionSrcLoc);
4485+
}
4486+
4487+
MemberExpr *MethodME = buildMemberExpr(MemberBaseExpr, Method);
4488+
4489+
QualType ResultTy = Method->getReturnType();
4490+
ExprValueKind VK = Expr::getValueKindForType(ResultTy);
4491+
ResultTy = ResultTy.getNonLValueExprType(SemaSYCLRef.getASTContext());
4492+
llvm::SmallVector<Expr *, 4> ParamStmts;
4493+
const auto *Proto = cast<FunctionProtoType>(Method->getType());
4494+
SemaSYCLRef.SemaRef.GatherArgumentsForCall(FreeFunctionSrcLoc, Method,
4495+
Proto, 0, ParamDREs, ParamStmts);
4496+
AddTo.push_back(CXXMemberCallExpr::Create(
4497+
SemaSYCLRef.getASTContext(), MethodME, ParamStmts, ResultTy, VK,
4498+
FreeFunctionSrcLoc, FPOptionsOverride()));
44174499
}
44184500

44194501
public:
@@ -4427,6 +4509,7 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
44274509
~FreeFunctionKernelBodyCreator() {
44284510
CompoundStmt *KernelBody = createFreeFunctionKernelBody();
44294511
DeclCreator.setBody(KernelBody);
4512+
DeclCreator.getKernelDecl()->dump();
44304513
}
44314514

44324515
bool handleSyclSpecialType(FieldDecl *FD, QualType Ty) final {
@@ -4435,9 +4518,41 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
44354518
return true;
44364519
}
44374520

4438-
bool handleSyclSpecialType(ParmVarDecl *, QualType) final {
4521+
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
44394522
// TODO
4440-
unsupportedFreeFunctionParamType();
4523+
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) {
4524+
const auto *RecordDecl = ParamTy->getAsCXXRecordDecl();
4525+
QualType Ty = PD->getOriginalType();
4526+
ASTContext &Ctx = SemaSYCLRef.SemaRef.getASTContext();
4527+
VarDecl *WorkGroupMemoryClone = VarDecl::Create(
4528+
Ctx, DeclCreator.getKernelDecl(), FreeFunctionSrcLoc, FreeFunctionSrcLoc, PD->getIdentifier(),
4529+
PD->getType(),
4530+
Ctx.getTrivialTypeSourceInfo(Ty), SC_None);
4531+
InitializedEntity VarEntity =
4532+
InitializedEntity::InitializeVariable(WorkGroupMemoryClone);
4533+
InitializationKind InitKind =
4534+
InitializationKind::CreateDefault(FreeFunctionSrcLoc);
4535+
InitializationSequence InitSeq(SemaSYCLRef.SemaRef, VarEntity, InitKind,
4536+
std::nullopt);
4537+
ExprResult Init = InitSeq.Perform(SemaSYCLRef.SemaRef, VarEntity,
4538+
InitKind, std::nullopt);
4539+
WorkGroupMemoryClone->setInit(
4540+
SemaSYCLRef.SemaRef.MaybeCreateExprWithCleanups(Init.get()));
4541+
WorkGroupMemoryClone->setInitStyle(VarDecl::CallInit);
4542+
Stmt *DS = new (SemaSYCLRef.getASTContext())
4543+
DeclStmt(DeclGroupRef(WorkGroupMemoryClone), FreeFunctionSrcLoc,
4544+
FreeFunctionSrcLoc);
4545+
BodyStmts.push_back(DS);
4546+
Expr *MemberBaseExpr = SemaSYCLRef.SemaRef.BuildDeclRefExpr(
4547+
WorkGroupMemoryClone, Ty, VK_LValue, FreeFunctionSrcLoc);
4548+
//createSpecialMethodCall(RecordDecl, InitMethodName, MemberBaseExpr,
4549+
//BodyStmts);
4550+
Expr *RvalueMemberBaseExpr = SemaSYCLRef.SemaRef.BuildDeclRefExpr(
4551+
WorkGroupMemoryClone, Ty, VK_PRValue, FreeFunctionSrcLoc);
4552+
4553+
ArgExprs.push_back(RvalueMemberBaseExpr);
4554+
} else
4555+
unsupportedFreeFunctionParamType();
44414556
return true;
44424557
}
44434558

@@ -4717,9 +4832,11 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
47174832
return true;
47184833
}
47194834

4720-
bool handleSyclSpecialType(ParmVarDecl *, QualType) final {
4835+
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
47214836
// TODO
4722-
unsupportedFreeFunctionParamType();
4837+
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory))
4838+
addParam(PD, ParamTy, SYCLIntegrationHeader::kind_work_group_memory);
4839+
else unsupportedFreeFunctionParamType();
47234840
return true;
47244841
}
47254842

@@ -4907,7 +5024,6 @@ class SYCLKernelNameTypeVisitor
49075024
void Visit(QualType T) {
49085025
if (T.isNull())
49095026
return;
4910-
49115027
const CXXRecordDecl *RD = T->getAsCXXRecordDecl();
49125028
// If KernelNameType has template args visit each template arg via
49135029
// ConstTemplateArgumentVisitor
@@ -5566,11 +5682,9 @@ void SemaSYCL::ProcessFreeFunction(FunctionDecl *FD) {
55665682
KernelObjVisitor Visitor{*this};
55675683

55685684
DiagnosingSYCLKernel = true;
5569-
55705685
// Check parameters of free function.
55715686
Visitor.VisitFunctionParameters(FD, DecompMarker, FieldChecker,
55725687
UnionChecker);
5573-
55745688
DiagnosingSYCLKernel = false;
55755689

55765690
// Ignore the free function if any of the checkers fail validation.
@@ -5911,7 +6025,7 @@ class SYCLFwdDeclEmitter
59116025
void Visit(QualType T) {
59126026
if (T.isNull())
59136027
return;
5914-
InnerTypeVisitor::Visit(T.getTypePtr());
6028+
InnerTypeVisitor::Visit(T.getTypePtr());
59156029
}
59166030

59176031
void VisitReferenceType(const ReferenceType *RT) {
@@ -5935,7 +6049,7 @@ class SYCLFwdDeclEmitter
59356049
}
59366050

59376051
void VisitTagType(const TagType *T) {
5938-
TagDecl *TD = T->getDecl();
6052+
TagDecl *TD = T->getDecl();
59396053
if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(TD)) {
59406054
// - first, recurse into template parameters and emit needed forward
59416055
// declarations
@@ -6015,13 +6129,13 @@ class SYCLFwdDeclEmitter
60156129
VisitTemplateArgs(TA.getPackAsArray());
60166130
}
60176131

6018-
void VisitFunctionProtoType(const FunctionProtoType *T) {
6019-
for (const auto Ty : T->getParamTypes())
6132+
void VisitFunctionProtoType(const FunctionProtoType *T) {
6133+
for (const auto Ty : T->getParamTypes())
60206134
Visit(Ty.getCanonicalType());
60216135
// So far this visitor method is only used for free function kernels whose
60226136
// return type is void anyway, so it is not visited. Otherwise, add if
60236137
// required.
6024-
}
6138+
}
60256139
};
60266140

60276141
class SYCLKernelNameTypePrinter
@@ -6196,6 +6310,16 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
61966310
O << "#include <sycl/detail/defines_elementary.hpp>\n";
61976311
O << "#include <sycl/detail/kernel_desc.hpp>\n";
61986312
O << "#include <sycl/ext/oneapi/experimental/free_function_traits.hpp>\n";
6313+
O << "#include <sycl/ext/oneapi/properties/properties.hpp>\n";
6314+
O << "namespace sycl { inline namespace _V1 { namespace ext { namespace "
6315+
"oneapi { namespace experimental {\n"
6316+
" template<typename DataT, typename PropertiesT = "
6317+
"properties<std::tuple<>>> class work_group_memory;\n"
6318+
"}\n"
6319+
"}\n"
6320+
"}\n"
6321+
"}\n"
6322+
"}\n";
61996323

62006324
O << "\n";
62016325

@@ -6465,8 +6589,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
64656589
O << "\n";
64666590
O << "// Forward declarations of kernel and its argument types:\n";
64676591
FwdDeclEmitter.Visit(K.SyclKernel->getType());
6468-
O << "\n";
6469-
6592+
O << "\n";
64706593
if (K.SyclKernel->getLanguageLinkage() == CLanguageLinkage)
64716594
O << "extern \"C\" ";
64726595
std::string ParmList;

sycl/include/sycl/ext/oneapi/experimental/work_group_memory.hpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,14 @@ class work_group_memory_impl {
3333
friend class sycl::handler;
3434
};
3535

36-
} // namespace detail
3736

37+
} // namespace detail
3838
namespace ext::oneapi::experimental {
39+
#ifdef __SYCL_DEVICE_ONLY__
3940
template <typename DataT, typename PropertyListT = empty_properties_t>
41+
#else
42+
template <typename DataT, typename PropertyListT>
43+
#endif
4044
class __SYCL_SPECIAL_CLASS __SYCL_TYPE(work_group_memory) work_group_memory
4145
: sycl::detail::work_group_memory_impl {
4246
public:
@@ -79,7 +83,15 @@ class __SYCL_SPECIAL_CLASS __SYCL_TYPE(work_group_memory) work_group_memory
7983
#endif
8084
private:
8185
decoratedPtr ptr;
86+
size_t bufferSize;
87+
template <typename DataType>
88+
friend size_t getWorkGroupMemorySize();
8289
};
90+
91+
template <typename DataType>
92+
size_t getWorkGroupMemorySize() {
93+
return work_group_memory<DataType, empty_properties_t>::bufferSize;
94+
}
8395
} // namespace ext::oneapi::experimental
8496
} // namespace _V1
8597
} // namespace sycl

0 commit comments

Comments
 (0)