Skip to content

Commit c1ffcf8

Browse files
[SYCL-Upstreaming] Add support for host kernel launch stmt generation (#51)
* Add support for host kernel launch stmt generation This adds generation of a call to the sycl_enqueue_kernel_launch function, aka the "launcher" function. The launcher function can be a member of a class or a free function defined at namespace scope. The lookup is performed from the scope of the SKEP-attributed function. Because unqualified lookup requires a Scope object to be present, and a Scope only exists during the parsing stage and has already been EOLed at the point where templates are instantiated, I had to move some parts of SYCLKernelCallStmt generation to earlier stages, and now TreeTransform knows how to process SYCLKernelCallStmt. I also had to invent a new expression - UnresolvedSYCLKernelExpr - which represents a string containing the kernel name of a kernel that doesn't exist yet. This expression is supposed to be transformed to a StringLiteral during the template instantiation phase. It should never reach AST consumers like CodeGen or constexpr evaluators. This still requires more testing and FIXME cleanups, but since it evolved into a quite complicated patch I'm pushing it for early feedback. 
* Remove a fixme from SemaSYCL * Do not crash if original body was invalid * Add AST test for skep-attributed member * Fix a warning * Extend codegen test a bit * Find and replace UnresolvedSYCLKernelNameExpr -> UnresolvedSYCLKernelLaunchExpr * Implement the thing * One more find and replace * I don't know how it looks like * Find and replace again * Switch to UnresolvedSYCLKernelEntryPointStmt * Apply suggestions from code review * Remove log.txt * Implement visiting * Add tests * Apply suggestions from code review Co-authored-by: Tom Honermann <[email protected]> * IdExpr -> KernelLaunchIdExpr * Don't rely on compound * UnresolvedSYCLKernelEntryPointStmt -> UnresolvedSYCLKernelCall * Fix warnings * Rename sycl_enqueue_kernel_launch -> sycl_kernel_launch * Apply suggestions from code review Co-authored-by: Tom Honermann <[email protected]> * Remove array decay * Add windows run line to the sema test --------- Co-authored-by: Tom Honermann <[email protected]>
1 parent 2a1c23d commit c1ffcf8

33 files changed

+780
-126
lines changed

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2990,6 +2990,13 @@ DEF_TRAVERSE_STMT(ParenListExpr, {})
29902990
DEF_TRAVERSE_STMT(SYCLUniqueStableNameExpr, {
29912991
TRY_TO(TraverseTypeLoc(S->getTypeSourceInfo()->getTypeLoc()));
29922992
})
2993+
// Traversal of UnresolvedSYCLKernelCallStmt. When the derived visitor asks to
// visit implicit code, both the original statement and the kernel launch
// id-expression are traversed explicitly and the default child traversal is
// turned off. Otherwise nothing is traversed here and ShouldVisitChildren is
// left untouched (presumably the macro's default child traversal then runs;
// confirm against DEF_TRAVERSE_STMT).
DEF_TRAVERSE_STMT(UnresolvedSYCLKernelCallStmt, {
2994+
if (getDerived().shouldVisitImplicitCode()) {
2995+
TRY_TO(TraverseStmt(S->getOriginalStmt()));
2996+
TRY_TO(TraverseStmt(S->getKernelLaunchIdExpr()));
2997+
// Both sub-nodes were traversed above; skip the default child walk.
ShouldVisitChildren = false;
2998+
}
2999+
})
29933000
DEF_TRAVERSE_STMT(OpenACCAsteriskSizeExpr, {})
29943001
DEF_TRAVERSE_STMT(PredefinedExpr, {})
29953002
DEF_TRAVERSE_STMT(ShuffleVectorExpr, {})

clang/include/clang/AST/StmtSYCL.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,59 @@ class SYCLKernelCallStmt : public Stmt {
9999
}
100100
};
101101

102+
// UnresolvedSYCLKernelCallStmt represents an invocation of a SYCL kernel in
103+
// a dependent context for which lookup of the sycl_kernel_launch
104+
// identifier cannot be performed. These statements are transformed to
105+
// SYCLKernelCallStmt during template instantiation.
106+
class UnresolvedSYCLKernelCallStmt : public Stmt {
107+
friend class ASTStmtReader;
108+
// OriginalStmt holds the compound statement supplied at construction. It is
// stored as a Stmt* and cast back to CompoundStmt by getOriginalStmt().
Stmt *OriginalStmt = nullptr;
109+
// KernelLaunchIdExpr stores an UnresolvedLookupExpr or UnresolvedMemberExpr
110+
// corresponding to the SYCL kernel launch function for which a call
111+
// will be synthesized during template instantiation.
112+
Expr *KernelLaunchIdExpr = nullptr;
113+
// Private constructor; instances are obtained through Create/CreateEmpty.
UnresolvedSYCLKernelCallStmt(CompoundStmt *CS, Expr *IdExpr)
114+
: Stmt(UnresolvedSYCLKernelCallStmtClass), OriginalStmt(CS),
115+
KernelLaunchIdExpr(IdExpr) {}
116+
117+
// Private setters; presumably used by the friend ASTStmtReader to fill in a
// node made by CreateEmpty (confirm against the reader implementation).
void setKernelLaunchIdExpr(Expr *IdExpr) { KernelLaunchIdExpr = IdExpr; }
118+
void setOriginalStmt(CompoundStmt *CS) { OriginalStmt = CS; }
119+
120+
public:
121+
// Allocates a new statement in the ASTContext C wrapping the compound
// statement CS and the launch function id-expression IdExpr.
static UnresolvedSYCLKernelCallStmt *
122+
Create(const ASTContext &C, CompoundStmt *CS, Expr *IdExpr) {
123+
return new (C) UnresolvedSYCLKernelCallStmt(CS, IdExpr);
124+
}
125+
126+
// Allocates a statement with both sub-nodes null.
// NOTE(review): getOriginalStmt(), getBeginLoc() and getEndLoc() apply
// cast<CompoundStmt> to OriginalStmt, so they presumably must not be called
// until OriginalStmt has been set (confirm cast<> null behavior).
static UnresolvedSYCLKernelCallStmt *CreateEmpty(const ASTContext &C) {
127+
return new (C) UnresolvedSYCLKernelCallStmt(nullptr, nullptr);
128+
}
129+
130+
Expr *getKernelLaunchIdExpr() const { return KernelLaunchIdExpr; }
131+
CompoundStmt *getOriginalStmt() { return cast<CompoundStmt>(OriginalStmt); }
132+
const CompoundStmt *getOriginalStmt() const {
133+
return cast<CompoundStmt>(OriginalStmt);
134+
}
135+
136+
// Source locations are taken from the original compound statement.
SourceLocation getBeginLoc() const LLVM_READONLY {
137+
return getOriginalStmt()->getBeginLoc();
138+
}
139+
140+
SourceLocation getEndLoc() const LLVM_READONLY {
141+
return getOriginalStmt()->getEndLoc();
142+
}
143+
static bool classof(const Stmt *T) {
144+
return T->getStmtClass() == UnresolvedSYCLKernelCallStmtClass;
145+
}
146+
// The child range covers only OriginalStmt; KernelLaunchIdExpr is not
// exposed through children().
child_range children() {
147+
return child_range(&OriginalStmt, &OriginalStmt + 1);
148+
}
149+
150+
const_child_range children() const {
151+
return const_child_range(&OriginalStmt, &OriginalStmt + 1);
152+
}
153+
};
154+
102155
} // end namespace clang
103156

104157
#endif // LLVM_CLANG_AST_STMTSYCL_H

clang/include/clang/Basic/AttrDocs.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ follows.
566566
namespace sycl {
567567
class handler {
568568
template<typename KernelNameType, typename... Ts>
569-
void sycl_enqueue_kernel_launch(const char *KernelName, Ts...) {
569+
void sycl_kernel_launch(const char *KernelName, Ts...) {
570570
// Call functions appropriate for the desired offload backend
571571
// (OpenCL, CUDA, HIP, Level Zero, etc...) to enqueue kernel invocation.
572572
}
@@ -634,7 +634,7 @@ The offload kernel entry point for a SYCL kernel performs the following tasks:
634634
The ``sycl_kernel_entry_point`` attribute facilitates or automates these tasks
635635
by generating the offload kernel entry point, generating a unique symbol name
636636
for it, synthesizing code for kernel argument decomposition and reconstruction,
637-
and synthesizing a call to a ``sycl_enqueue_kernel_launch`` function template
637+
and synthesizing a call to a ``sycl_kernel_launch`` function template
638638
with the kernel name type, kernel symbol name, and (decomposed) kernel arguments
639639
passed as template or function arguments.
640640

@@ -702,7 +702,7 @@ replaced with synthesized code that looks approximately as follows.
702702

703703
sycl::stream sout = Kernel.sout;
704704
S s = Kernel.s;
705-
sycl_enqueue_kernel_launch<KN>("kernel-symbol-name", sout, s);
705+
sycl_kernel_launch<KN>("kernel-symbol-name", sout, s);
706706

707707
There are a few items worthy of note:
708708

@@ -713,16 +713,16 @@ There are a few items worthy of note:
713713
#. ``kernel-symbol-name`` is substituted for the actual symbol name that would
714714
be generated; these names are implementation details subject to change.
715715

716-
#. Lookup for the ``sycl_enqueue_kernel_launch()`` function template is
716+
#. Lookup for the ``sycl_kernel_launch()`` function template is
717717
performed from the (possibly instantiated) location of the definition of
718718
``kernel_entry_point()``. If overload resolution fails, the program is
719719
ill-formed. If the selected overload is a non-static member function, then
720720
``this`` is passed for the implicit object parameter.
721721

722-
#. Function arguments passed to ``sycl_enqueue_kernel_launch()`` are passed
722+
#. Function arguments passed to ``sycl_kernel_launch()`` are passed
723723
as if by ``std::forward<X>(x)``.
724724

725-
#. The ``sycl_enqueue_kernel_launch()`` function is expected to be provided by
725+
#. The ``sycl_kernel_launch()`` function is expected to be provided by
726726
the SYCL library implementation. It is responsible for scheduling execution
727727
of the generated offload kernel entry point identified by
728728
``kernel-symbol-name`` and copying the (decomposed) kernel arguments to
@@ -733,7 +733,7 @@ attribute to be called for the offload kernel entry point to be emitted. For
733733
inline functions and function templates, any ODR-use will suffice. For other
734734
functions, an ODR-use is not required; the offload kernel entry point will be
735735
emitted if the function is defined. In any case, a call to the function is
736-
required for the synthesized call to ``sycl_enqueue_kernel_launch()`` to occur.
736+
required for the synthesized call to ``sycl_kernel_launch()`` to occur.
737737

738738
Functions declared with the ``sycl_kernel_entry_point`` attribute are not
739739
limited to the simple example shown above. They may have additional template

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13088,6 +13088,15 @@ def err_sycl_entry_point_return_type : Error<
1308813088
def err_sycl_entry_point_deduced_return_type : Error<
1308913089
"the %0 attribute only applies to functions with a non-deduced 'void' return"
1309013090
" type">;
13091+
def err_sycl_host_no_launch_function : Error<
13092+
"unable to find suitable 'sycl_kernel_launch' function for host code "
13093+
"synthesis">;
13094+
def warn_sycl_device_no_host_launch_function : Warning<
13095+
"unable to find suitable 'sycl_kernel_launch' function for host code "
13096+
"synthesis">,
13097+
InGroup<DiagGroup<"sycl-host-launcher">>;
13098+
def note_sycl_host_launch_function : Note<
13099+
"define 'sycl_kernel_launch' function template to fix this problem">;
1309113100

1309213101
def warn_cuda_maxclusterrank_sm_90 : Warning<
1309313102
"maxclusterrank requires sm_90 or higher, CUDA arch provided: %0, ignoring "

clang/include/clang/Basic/StmtNodes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def CaseStmt : StmtNode<SwitchCase>;
2323
def DefaultStmt : StmtNode<SwitchCase>;
2424
def CapturedStmt : StmtNode<Stmt>;
2525
def SYCLKernelCallStmt : StmtNode<Stmt>;
26+
def UnresolvedSYCLKernelCallStmt : StmtNode<Stmt>;
2627

2728
// Break/continue.
2829
def LoopControlStmt : StmtNode<Stmt, 1>;

clang/include/clang/Sema/ScopeInfo.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,10 @@ class FunctionScopeInfo {
245245
/// The set of GNU address of label extension "&&label".
246246
llvm::SmallVector<AddrLabelExpr *, 4> AddrLabels;
247247

248+
/// An unresolved identifier lookup expression for an implicit call
249+
/// to a SYCL kernel launch function in a dependent context.
250+
Expr *SYCLKernelLaunchIdExpr = nullptr;
251+
248252
public:
249253
/// Represents a simple identification of a weak object.
250254
///

clang/include/clang/Sema/SemaSYCL.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,11 @@ class SemaSYCL : public SemaBase {
6666

6767
void CheckSYCLExternalFunctionDecl(FunctionDecl *FD);
6868
void CheckSYCLEntryPointFunctionDecl(FunctionDecl *FD);
69-
StmtResult BuildSYCLKernelCallStmt(FunctionDecl *FD, CompoundStmt *Body);
69+
StmtResult BuildSYCLKernelCallStmt(FunctionDecl *FD, CompoundStmt *Body,
70+
Expr *LaunchIdExpr);
71+
ExprResult BuildSYCLKernelLaunchIdExpr(FunctionDecl *FD, QualType KNT);
72+
StmtResult BuildUnresolvedSYCLKernelCallStmt(CompoundStmt *CS,
73+
Expr *IdExpr);
7074
};
7175

7276
} // namespace clang

clang/include/clang/Serialization/ASTBitCodes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1615,6 +1615,9 @@ enum StmtCode {
16151615
/// A SYCLKernelCallStmt record.
16161616
STMT_SYCLKERNELCALL,
16171617

1618+
/// An UnresolvedSYCLKernelCallStmt record.
1619+
STMT_UNRESOLVED_SYCL_KERNEL_CALL,
1620+
16181621
/// A GCC-style AsmStmt record.
16191622
STMT_GCCASM,
16201623

clang/lib/AST/ComputeDependence.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "clang/AST/ExprConcepts.h"
1717
#include "clang/AST/ExprObjC.h"
1818
#include "clang/AST/ExprOpenMP.h"
19+
#include "clang/AST/StmtSYCL.h"
1920
#include "clang/Basic/ExceptionSpecificationType.h"
2021
#include "llvm/ADT/ArrayRef.h"
2122

clang/lib/AST/StmtPrinter.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,6 +1442,11 @@ void StmtPrinter::VisitSYCLUniqueStableNameExpr(
14421442
OS << ")";
14431443
}
14441444

1445+
// Prints an UnresolvedSYCLKernelCallStmt by printing only its original
// statement; the kernel launch id-expression is not printed.
void StmtPrinter::VisitUnresolvedSYCLKernelCallStmt(
1446+
UnresolvedSYCLKernelCallStmt *Node) {
1447+
PrintStmt(Node->getOriginalStmt());
1448+
}
1449+
14451450
void StmtPrinter::VisitPredefinedExpr(PredefinedExpr *Node) {
14461451
OS << PredefinedExpr::getIdentKindName(Node->getIdentKind());
14471452
}

0 commit comments

Comments
 (0)