Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 136 additions & 29 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,12 @@ class KernelObjVisitor {
std::initializer_list<int>{(result = result && tn(BD, BDTy), 0)...};
return result;
}
template <typename... Tn>
bool handleField(ParmVarDecl *PD, QualType PDTy, Tn &&...tn) {
bool result = true;
std::initializer_list<int>{(result = result && tn(PD, PDTy), 0)...};
return result;
}

// This definition using std::bind is necessary because of a gcc 7.x bug.
#define KF_FOR_EACH(FUNC, Item, Qt) \
Expand Down Expand Up @@ -1443,9 +1449,12 @@ class KernelObjVisitor {
HandlerTys &...Handlers) {
if (isSyclSpecialType(ParamTy, SemaSYCLRef))
KP_FOR_EACH(handleOtherType, Param, ParamTy);
else if (ParamTy->isStructureOrClassType())
KP_FOR_EACH(handleOtherType, Param, ParamTy);
else if (ParamTy->isUnionType())
else if (ParamTy->isStructureOrClassType()) {
if (KF_FOR_EACH(handleStructType, Param, ParamTy)) {
CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl();
visitRecord(RD, Param, RD, ParamTy, Handlers...);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing RD for both the 1st and 3rd arguments seems surprising here. The situation doesn't seem quite analogous to visitField() above. I'm having a difficult time figuring out exactly what visitRecord() is actually intending to do; the owner/wrapper distinction seems weird to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, a lot of code around the visitors is intended to have some "owner" because the original use case is lambda/functor whose fields need to be visited. However it doesn't seem to affect things anyhow and I suspect not having it is fine. Also this comment

// type (which doesn't exist in cases where it is a FieldDecl in the

suggests so.
I transformed this argument to nullptr to avoid confusion.

}
} else if (ParamTy->isUnionType())
KP_FOR_EACH(handleOtherType, Param, ParamTy);
else if (ParamTy->isReferenceType())
KP_FOR_EACH(handleOtherType, Param, ParamTy);
Expand Down Expand Up @@ -1957,8 +1966,25 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
}

bool handleStructType(ParmVarDecl *PD, QualType ParamTy) final {
Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type) << ParamTy;
IsInvalid = true;
CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl();
// For free functions all struct/class kernel arguments are forward declared
// in integration header, that adds additional restrictions for kernel
// arguments.
// Lambdas are not forward declarable. So, diagnose them properly.
if (RD->isLambda()) {
Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type)
<< ParamTy;
IsInvalid = true;
return isValid();
}

// Check that the type is defined at namespace scope.
const DeclContext *DeclCtx = RD->getDeclContext();
if (!DeclCtx->isTranslationUnit() && !isa<NamespaceDecl>(DeclCtx)) {
Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type)
<< ParamTy;
IsInvalid = true;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs updates to handle declarations in ExternCContextDecl and LinkageSpecDecl declaration contexts. We should presumably traverse through those to the enclosing TranslationUnitDecl or NamespaceDecl context. Tests for that would be good; the forward declaration in the integration header should reproduce the enclosing extern "C", extern "C++", etc... context.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added handling for LinkageSpecDecl here, however I'm not able to write the code that would give me ExternCContextDecl to handle. Does clang emit it still?

the forward declaration in the integration header should reproduce the enclosing extern "C", extern "C++", etc... context.

I wonder, what would be the benefit of doing that? I suppose linkage declaration contexts shouldn't affect the name. I see that the code generating forward declarations is intentionally skipping LinkageSpecDecl .

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked into ExternCContextDecl and confirmed that it is never used as either a lexical or semantic context for a declaration (it is used to collect and identify extern "C" declarations that might appear in distinct lexical contexts). So, nothing to be done for it; the change to handle LinkageSpecDecl is all that is needed.

I wonder, what would be the benefit of doing that? I suppose linkage declaration contexts shouldn't affect the name. I see that the code generating forward declarations is intentionally skipping LinkageSpecDecl .

It could affect name mangling for variable and function declarations which in turn could affect mangling elsewhere. A highly contrived example involving kernel names is below. It looks like icx currently fails to handle such cases regardless of whether f() is declared extern "C" though; https://godbolt.org/z/KWhh5xWq9. In general, it looks like icx fails to generate correct integration headers for class templates with non-type template parameters that reference other symbols.

#include <sycl/sycl.hpp>
extern "C" void f();
template<void(*)()> class kernel_name {};
int main() {
  sycl::queue q;
  q.submit([](sycl::handler &h) {
	h.single_task<kernel_name<f>>([]{});
  });
  q.wait();
}

I don't know how important such cases are. Since SYCL doesn't support function pointers in device code, it could be useful to smuggle a function reference through the type system. icx accepts the following example: https://godbolt.org/z/TGW5Ynnjz.

#include <sycl/sycl.hpp>
extern "C" SYCL_EXTERNAL void f();
template<void(&FN)()> struct X {
  void operator()() const {
	FN();
  }
};
int main() {
  sycl::queue q;
  X<f> x;
  q.submit([=](sycl::handler &h) {
	h.single_task<struct KN>([=]{ x(); });
  });
  q.wait();
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it seems the problem here is that we don't support forward declaring functions in the integration header. I'm not sure how useful the examples above, since plain SYCL doesn't allow function pointers anyway. I see that supporting these cases may require big functional changes around integration header generation, so I'm not sure if we should do this as a part of the PR. My preference is to add support for these changes as a separate PR. WDYT?

return isValid();
}

Expand Down Expand Up @@ -2037,14 +2063,16 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
}

bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final {
// TODO
unsupportedFreeFunctionParamType();
// TODO manipulate struct depth once special types are supported for free
// function kernels.
// ++StructFieldDepth;
Comment on lines +2137 to +2139
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we at least diagnose cases that involve SYCL special types now?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC this is already diagnosed by calling handleOtherType during visitation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually maybe not. It is not obvious to me whether we will hit that code since we aren't decomposing yet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we didn't diagnose because we don't decompose yet. I added diagnosing.

return true;
}

bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final {
// TODO
unsupportedFreeFunctionParamType();
// TODO manipulate struct depth once special types are supported for free
// function kernels.
// --StructFieldDepth;
return true;
}

Expand Down Expand Up @@ -2162,8 +2190,7 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler {
}

bool handlePointerType(ParmVarDecl *, QualType) final {
// TODO
unsupportedFreeFunctionParamType();
PointerStack.back() = targetRequiresNewType(SemaSYCLRef.getASTContext());
return true;
}

Expand Down Expand Up @@ -2194,8 +2221,10 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler {
}

bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final {
// TODO
unsupportedFreeFunctionParamType();
// TODO handle decomposition once special type arguments are supported
// for free function kernels.
// CollectionStack.push_back(false);
PointerStack.push_back(false);
return true;
}

Expand All @@ -2221,10 +2250,24 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler {
return true;
}

bool leaveStruct(const CXXRecordDecl *RD, ParmVarDecl *PD,
bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *,
QualType ParamTy) final {
// TODO
unsupportedFreeFunctionParamType();
CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl();
assert(RD && "should not be null.");
// TODO handle decomposition once special type arguments are supported
// for free function kernels.
// if (CollectionStack.pop_back_val()) {
// if (!RD->hasAttr<SYCLRequiresDecompositionAttr>())
// RD->addAttr(SYCLRequiresDecompositionAttr::CreateImplicit(
// SemaSYCLRef.getASTContext()));
// CollectionStack.back() = true;
// PointerStack.pop_back();
if (PointerStack.pop_back_val()) {
PointerStack.back() = true;
if (!RD->hasAttr<SYCLGenerateNewTypeAttr>())
RD->addAttr(SYCLGenerateNewTypeAttr::CreateImplicit(
SemaSYCLRef.getASTContext()));
}
return true;
}

Expand Down Expand Up @@ -2974,8 +3017,15 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {

bool handleNonDecompStruct(const CXXRecordDecl *RD, ParmVarDecl *PD,
QualType ParamTy) final {
// TODO
unsupportedFreeFunctionParamType();
// This is a field which should not be decomposed.
CXXRecordDecl *FieldRecordDecl = ParamTy->getAsCXXRecordDecl();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FieldRecordDecl doesn't seem like the right name here. Perhaps ParamRecordDecl?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, thanks for the catch!

assert(FieldRecordDecl && "Type must be a C++ record type");
// Check if we need to generate a new type for this record,
// i.e. this record contains pointers.
if (FieldRecordDecl->hasAttr<SYCLGenerateNewTypeAttr>())
addParam(PD, GenerateNewRecordType(FieldRecordDecl));
else
addParam(PD, ParamTy);
return true;
}

Expand Down Expand Up @@ -3203,8 +3253,7 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler {

bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *,
QualType ParamTy) final {
// TODO
unsupportedFreeFunctionParamType();
addParam(ParamTy);
return true;
}

Expand Down Expand Up @@ -4194,7 +4243,7 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {

// Creates a DeclRefExpr to the ParmVar that represents the current pointer
// parameter.
Expr *createPointerParamReferenceExpr(QualType PointerTy, bool Wrapped) {
Expr *createPointerParamReferenceExpr(QualType PointerTy) {
ParmVarDecl *FreeFunctionParameter =
DeclCreator.getParamVarDeclsForCurrentField()[0];

Expand All @@ -4212,6 +4261,50 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
return DRE;
}

Expr *createGetAddressOf(Expr *E) {
return UnaryOperator::Create(
SemaSYCLRef.getASTContext(), E, UO_AddrOf,
SemaSYCLRef.getASTContext().getPointerType(E->getType()), VK_PRValue,
OK_Ordinary, SourceLocation(), false,
SemaSYCLRef.SemaRef.CurFPFeatureOverrides());
}

Expr *createDerefOp(Expr *E) {
return UnaryOperator::Create(SemaSYCLRef.getASTContext(), E, UO_Deref,
E->getType()->getPointeeType(), VK_LValue,
OK_Ordinary, SourceLocation(), false,
SemaSYCLRef.SemaRef.CurFPFeatureOverrides());
Comment on lines +4355 to +4358
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency with surrounding functions. Clang-format might have other ideas though.

Suggested change
return UnaryOperator::Create(SemaSYCLRef.getASTContext(), E, UO_Deref,
E->getType()->getPointeeType(), VK_LValue,
OK_Ordinary, SourceLocation(), false,
SemaSYCLRef.SemaRef.CurFPFeatureOverrides());
return UnaryOperator::Create(
SemaSYCLRef.getASTContext(), E, UO_Deref,
E->getType()->getPointeeType(), VK_LValue,
OK_Ordinary, SourceLocation(), false,
SemaSYCLRef.SemaRef.CurFPFeatureOverrides());

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clang-format doesn't agree.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clang-format makes bad choices sometimes :)

}

Expr *createReinterpretCastExpr(Expr *E, QualType To) {
return CXXReinterpretCastExpr::Create(
SemaSYCLRef.getASTContext(), To, VK_PRValue, CK_BitCast, E,
/*Path=*/nullptr,
SemaSYCLRef.getASTContext().getTrivialTypeSourceInfo(To),
SourceLocation(), SourceLocation(), SourceRange());
}

Expr *createStructTemporary(ParmVarDecl *OrigFunctionParameter) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name suggestion:

Suggested change
Expr *createStructTemporary(ParmVarDecl *OrigFunctionParameter) {
Expr *createCopyInitExpr(ParmVarDecl *OrigFunctionParameter) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

Expr *DRE = createParamReferenceExpr();

assert(OrigFunctionParameter && "no parameter?");

CXXRecordDecl *RD = OrigFunctionParameter->getType()->getAsCXXRecordDecl();
InitializedEntity Entity = InitializedEntity::InitializeParameter(
SemaSYCLRef.getASTContext(), OrigFunctionParameter);

if (RD->hasAttr<SYCLGenerateNewTypeAttr>()) {
DRE = createReinterpretCastExpr(
createGetAddressOf(DRE), SemaSYCLRef.getASTContext().getPointerType(
OrigFunctionParameter->getType()));
DRE = createDerefOp(DRE);
}

ExprResult ArgE = SemaSYCLRef.SemaRef.PerformCopyInitialization(
Entity, SourceLocation(), DRE, false, false);
return ArgE.getAs<Expr>();
}

// For a free function such as:
// void f(int i, int* p, struct Simple S) { ... }
//
Expand Down Expand Up @@ -4281,7 +4374,7 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
}

bool handlePointerType(ParmVarDecl *PD, QualType ParamTy) final {
Expr *PointerRef = createPointerParamReferenceExpr(ParamTy, false);
Expr *PointerRef = createPointerParamReferenceExpr(ParamTy);
ArgExprs.push_back(PointerRef);
return true;
}
Expand All @@ -4299,10 +4392,10 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
return true;
}

bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *,
bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *PD,
QualType) final {
// TODO
unsupportedFreeFunctionParamType();
Expr *TempCopy = createStructTemporary(PD);
ArgExprs.push_back(TempCopy);
return true;
}

Expand Down Expand Up @@ -4588,8 +4681,7 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {

bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *PD,
QualType ParamTy) final {
// TODO
unsupportedFreeFunctionParamType();
addParam(PD, ParamTy, SYCLIntegrationHeader::kind_std_layout);
return true;
}

Expand Down Expand Up @@ -5435,6 +5527,7 @@ void SemaSYCL::MarkDevices() {

void SemaSYCL::ProcessFreeFunction(FunctionDecl *FD) {
if (isFreeFunction(*this, FD)) {
SyclKernelDecompMarker DecompMarker(*this);
SyclKernelFieldChecker FieldChecker(*this);
SyclKernelUnionChecker UnionChecker(*this);

Expand All @@ -5443,7 +5536,8 @@ void SemaSYCL::ProcessFreeFunction(FunctionDecl *FD) {
DiagnosingSYCLKernel = true;

// Check parameters of free function.
Visitor.VisitFunctionParameters(FD, FieldChecker, UnionChecker);
Visitor.VisitFunctionParameters(FD, FieldChecker, UnionChecker,
DecompMarker);

DiagnosingSYCLKernel = false;

Expand Down Expand Up @@ -5889,6 +5983,14 @@ class SYCLFwdDeclEmitter
void VisitPackTemplateArgument(const TemplateArgument &TA) {
VisitTemplateArgs(TA.getPackAsArray());
}

void VisitFunctionProtoType(const FunctionProtoType *T) {
for (const auto Ty : T->getParamTypes())
Visit(Ty.getCanonicalType());
// So far this visitor method is only used for free function kernels whose
// return type is void anyway, so it is not visited. Otherwise, add if
// required.
}
};

class SYCLKernelNameTypePrinter
Expand Down Expand Up @@ -6325,10 +6427,15 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
for (const KernelDesc &K : KernelDescs) {
if (!isFreeFunction(S, K.SyclKernel))
continue;

++FreeFunctionCount;
// Generate forward declaration for free function.
O << "\n// Definition of " << K.Name << " as a free function kernel\n";

O << "\n";
O << "// Forward declarations of kernel and its argument types:\n";
FwdDeclEmitter.Visit(K.SyclKernel->getType());
O << "\n";

if (K.SyclKernel->getLanguageLinkage() == CLanguageLinkage)
O << "extern \"C\" ";
std::string ParmList;
Expand Down
Loading