Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 140 additions & 29 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,12 @@ class KernelObjVisitor {
std::initializer_list<int>{(result = result && tn(BD, BDTy), 0)...};
return result;
}
template <typename... Tn>
bool handleField(ParmVarDecl *PD, QualType PDTy, Tn &&...tn) {
bool result = true;
std::initializer_list<int>{(result = result && tn(PD, PDTy), 0)...};
return result;
}

// This definition using std::bind is necessary because of a gcc 7.x bug.
#define KF_FOR_EACH(FUNC, Item, Qt) \
Expand Down Expand Up @@ -1443,9 +1449,12 @@ class KernelObjVisitor {
HandlerTys &...Handlers) {
if (isSyclSpecialType(ParamTy, SemaSYCLRef))
KP_FOR_EACH(handleOtherType, Param, ParamTy);
else if (ParamTy->isStructureOrClassType())
KP_FOR_EACH(handleOtherType, Param, ParamTy);
else if (ParamTy->isUnionType())
else if (ParamTy->isStructureOrClassType()) {
if (KF_FOR_EACH(handleStructType, Param, ParamTy)) {
CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl();
visitRecord(nullptr, Param, RD, ParamTy, Handlers...);
}
} else if (ParamTy->isUnionType())
KP_FOR_EACH(handleOtherType, Param, ParamTy);
else if (ParamTy->isReferenceType())
KP_FOR_EACH(handleOtherType, Param, ParamTy);
Expand Down Expand Up @@ -1957,8 +1966,29 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
}

bool handleStructType(ParmVarDecl *PD, QualType ParamTy) final {
Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type) << ParamTy;
IsInvalid = true;
CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl();
// For free functions all struct/class kernel arguments are forward declared
// in integration header, that adds additional restrictions for kernel
// arguments.
// Lambdas are not forward declarable. So, diagnose them properly.
if (RD->isLambda()) {
Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type)
<< ParamTy;
IsInvalid = true;
return isValid();
}

// Check that the type is defined at namespace scope.
const DeclContext *DeclCtx = RD->getDeclContext();
while (!DeclCtx->isTranslationUnit() &&
(isa<NamespaceDecl>(DeclCtx) || isa<LinkageSpecDecl>(DeclCtx)))
DeclCtx = DeclCtx->getParent();

if (!DeclCtx->isTranslationUnit()) {
Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type)
<< ParamTy;
IsInvalid = true;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps it would make sense to factor this out to an isForwardDeclarable() function? I looked for an existing one, but didn't find one. There are several cases where a forward declarable declaration is required and I'm already skeptical that we're diagnosing violations correctly. See DiagnoseKernelNameType() for additional code that could be factored out and merged.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

return isValid();
}

Expand Down Expand Up @@ -2037,14 +2067,16 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
}

bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final {
// TODO
unsupportedFreeFunctionParamType();
// TODO manipulate struct depth once special types are supported for free
// function kernels.
// ++StructFieldDepth;
Comment on lines +2137 to +2139
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we at least diagnose cases that involve SYCL special types now?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC this is already diagnosed by calling handleOtherType during visitation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually maybe not. It is not obvious to me whether we will hit that code since we aren't decomposing yet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we didn't diagnose because we don't decompose yet. I added diagnosing.

return true;
}

bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final {
// TODO
unsupportedFreeFunctionParamType();
// TODO manipulate struct depth once special types are supported for free
// function kernels.
// --StructFieldDepth;
return true;
}

Expand Down Expand Up @@ -2162,8 +2194,7 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler {
}

bool handlePointerType(ParmVarDecl *, QualType) final {
// TODO
unsupportedFreeFunctionParamType();
PointerStack.back() = targetRequiresNewType(SemaSYCLRef.getASTContext());
return true;
}

Expand Down Expand Up @@ -2194,8 +2225,10 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler {
}

bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final {
// TODO
unsupportedFreeFunctionParamType();
// TODO handle decomposition once special type arguments are supported
// for free function kernels.
// CollectionStack.push_back(false);
PointerStack.push_back(false);
return true;
}

Expand All @@ -2221,10 +2254,24 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler {
return true;
}

bool leaveStruct(const CXXRecordDecl *RD, ParmVarDecl *PD,
bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *,
QualType ParamTy) final {
// TODO
unsupportedFreeFunctionParamType();
CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl();
assert(RD && "should not be null.");
// TODO handle decomposition once special type arguments are supported
// for free function kernels.
// if (CollectionStack.pop_back_val()) {
// if (!RD->hasAttr<SYCLRequiresDecompositionAttr>())
// RD->addAttr(SYCLRequiresDecompositionAttr::CreateImplicit(
// SemaSYCLRef.getASTContext()));
// CollectionStack.back() = true;
// PointerStack.pop_back();
if (PointerStack.pop_back_val()) {
PointerStack.back() = true;
if (!RD->hasAttr<SYCLGenerateNewTypeAttr>())
RD->addAttr(SYCLGenerateNewTypeAttr::CreateImplicit(
SemaSYCLRef.getASTContext()));
}
return true;
}

Expand Down Expand Up @@ -2974,8 +3021,15 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {

bool handleNonDecompStruct(const CXXRecordDecl *RD, ParmVarDecl *PD,
QualType ParamTy) final {
// TODO
unsupportedFreeFunctionParamType();
// This is a struct parameter which should not be decomposed.
CXXRecordDecl *ParamRecordDecl = ParamTy->getAsCXXRecordDecl();
assert(ParamRecordDecl && "Type must be a C++ record type");
// Check if we need to generate a new type for this record,
// i.e. this record contains pointers.
if (ParamRecordDecl->hasAttr<SYCLGenerateNewTypeAttr>())
addParam(PD, GenerateNewRecordType(ParamRecordDecl));
else
addParam(PD, ParamTy);
return true;
}

Expand Down Expand Up @@ -3203,8 +3257,7 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler {

bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *,
QualType ParamTy) final {
// TODO
unsupportedFreeFunctionParamType();
addParam(ParamTy);
return true;
}

Expand Down Expand Up @@ -4194,7 +4247,7 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {

// Creates a DeclRefExpr to the ParmVar that represents the current pointer
// parameter.
Expr *createPointerParamReferenceExpr(QualType PointerTy, bool Wrapped) {
Expr *createPointerParamReferenceExpr(QualType PointerTy) {
ParmVarDecl *FreeFunctionParameter =
DeclCreator.getParamVarDeclsForCurrentField()[0];

Expand All @@ -4212,6 +4265,50 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
return DRE;
}

Expr *createGetAddressOf(Expr *E) {
return UnaryOperator::Create(
SemaSYCLRef.getASTContext(), E, UO_AddrOf,
SemaSYCLRef.getASTContext().getPointerType(E->getType()), VK_PRValue,
OK_Ordinary, SourceLocation(), false,
SemaSYCLRef.SemaRef.CurFPFeatureOverrides());
}

Expr *createDerefOp(Expr *E) {
return UnaryOperator::Create(SemaSYCLRef.getASTContext(), E, UO_Deref,
E->getType()->getPointeeType(), VK_LValue,
OK_Ordinary, SourceLocation(), false,
SemaSYCLRef.SemaRef.CurFPFeatureOverrides());
Comment on lines +4355 to +4358
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency with surrounding functions. Clang-format might have other ideas though.

Suggested change
return UnaryOperator::Create(SemaSYCLRef.getASTContext(), E, UO_Deref,
E->getType()->getPointeeType(), VK_LValue,
OK_Ordinary, SourceLocation(), false,
SemaSYCLRef.SemaRef.CurFPFeatureOverrides());
return UnaryOperator::Create(
SemaSYCLRef.getASTContext(), E, UO_Deref,
E->getType()->getPointeeType(), VK_LValue,
OK_Ordinary, SourceLocation(), false,
SemaSYCLRef.SemaRef.CurFPFeatureOverrides());

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clang-format doesn't agree.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clang-format makes bad choices sometimes :)

}

Expr *createReinterpretCastExpr(Expr *E, QualType To) {
return CXXReinterpretCastExpr::Create(
SemaSYCLRef.getASTContext(), To, VK_PRValue, CK_BitCast, E,
/*Path=*/nullptr,
SemaSYCLRef.getASTContext().getTrivialTypeSourceInfo(To),
SourceLocation(), SourceLocation(), SourceRange());
}

Expr *createStructTemporary(ParmVarDecl *OrigFunctionParameter) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name suggestion:

Suggested change
Expr *createStructTemporary(ParmVarDecl *OrigFunctionParameter) {
Expr *createCopyInitExpr(ParmVarDecl *OrigFunctionParameter) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

Expr *DRE = createParamReferenceExpr();

assert(OrigFunctionParameter && "no parameter?");

CXXRecordDecl *RD = OrigFunctionParameter->getType()->getAsCXXRecordDecl();
InitializedEntity Entity = InitializedEntity::InitializeParameter(
SemaSYCLRef.getASTContext(), OrigFunctionParameter);

if (RD->hasAttr<SYCLGenerateNewTypeAttr>()) {
DRE = createReinterpretCastExpr(
createGetAddressOf(DRE), SemaSYCLRef.getASTContext().getPointerType(
OrigFunctionParameter->getType()));
DRE = createDerefOp(DRE);
}

ExprResult ArgE = SemaSYCLRef.SemaRef.PerformCopyInitialization(
Entity, SourceLocation(), DRE, false, false);
return ArgE.getAs<Expr>();
}

// For a free function such as:
// void f(int i, int* p, struct Simple S) { ... }
//
Expand Down Expand Up @@ -4281,7 +4378,7 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
}

bool handlePointerType(ParmVarDecl *PD, QualType ParamTy) final {
Expr *PointerRef = createPointerParamReferenceExpr(ParamTy, false);
Expr *PointerRef = createPointerParamReferenceExpr(ParamTy);
ArgExprs.push_back(PointerRef);
return true;
}
Expand All @@ -4299,10 +4396,10 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
return true;
}

bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *,
bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *PD,
QualType) final {
// TODO
unsupportedFreeFunctionParamType();
Expr *TempCopy = createStructTemporary(PD);
ArgExprs.push_back(TempCopy);
return true;
}

Expand Down Expand Up @@ -4588,8 +4685,7 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {

bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *PD,
QualType ParamTy) final {
// TODO
unsupportedFreeFunctionParamType();
addParam(PD, ParamTy, SYCLIntegrationHeader::kind_std_layout);
return true;
}

Expand Down Expand Up @@ -5435,6 +5531,7 @@ void SemaSYCL::MarkDevices() {

void SemaSYCL::ProcessFreeFunction(FunctionDecl *FD) {
if (isFreeFunction(*this, FD)) {
SyclKernelDecompMarker DecompMarker(*this);
SyclKernelFieldChecker FieldChecker(*this);
SyclKernelUnionChecker UnionChecker(*this);

Expand All @@ -5443,7 +5540,8 @@ void SemaSYCL::ProcessFreeFunction(FunctionDecl *FD) {
DiagnosingSYCLKernel = true;

// Check parameters of free function.
Visitor.VisitFunctionParameters(FD, FieldChecker, UnionChecker);
Visitor.VisitFunctionParameters(FD, FieldChecker, UnionChecker,
DecompMarker);

DiagnosingSYCLKernel = false;

Expand Down Expand Up @@ -5889,6 +5987,14 @@ class SYCLFwdDeclEmitter
void VisitPackTemplateArgument(const TemplateArgument &TA) {
VisitTemplateArgs(TA.getPackAsArray());
}

void VisitFunctionProtoType(const FunctionProtoType *T) {
for (const auto Ty : T->getParamTypes())
Visit(Ty.getCanonicalType());
// So far this visitor method is only used for free function kernels whose
// return type is void anyway, so it is not visited. Otherwise, add if
// required.
}
};

class SYCLKernelNameTypePrinter
Expand Down Expand Up @@ -6325,10 +6431,15 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
for (const KernelDesc &K : KernelDescs) {
if (!isFreeFunction(S, K.SyclKernel))
continue;

++FreeFunctionCount;
// Generate forward declaration for free function.
O << "\n// Definition of " << K.Name << " as a free function kernel\n";

O << "\n";
O << "// Forward declarations of kernel and its argument types:\n";
FwdDeclEmitter.Visit(K.SyclKernel->getType());
O << "\n";

if (K.SyclKernel->getLanguageLinkage() == CLanguageLinkage)
O << "extern \"C\" ";
std::string ParmList;
Expand Down
Loading