Skip to content

Commit dd8f6cf

Browse files
[SYCL-Upstreaming] Add support for host kernel launch stmt generation (#51)
* Add support for host kernel launch stmt generation This adds generation of a call to sycl_enqueue_kernel_launch function aka "launcher" function. The launcher function can be a memeber of a class or a free function defined at namespace scope. The lookup is performed from SKEP attributed function scope. Because unqualified lookup requires Scope object present and it only exists during parsing stage and already EOLed at the point where templates instantiated, I had to move some parts of SYCLKernelCallStmt generation to earlier stages and now TreeTransform knows how to process SYCLKernelCallStmt. I also had to invent a new expression - UnresolvedSYCLKernelExpr which represents a string containing kernel name of a kernel that doesn't exist yet. This expression is supposed to be transformed to a StringLiteral during template instantiation phase. It should never reach AST consumers like CodeGen of constexpr evaluators. This still requires more testing and FIXME cleanups, but since it evolved into a quite complicated patch I'm pushing it for earlier feedback. * Remove a fixme from SemaSYCL * Do not crash if original body was invalid * Add AST test for skep-attributed member * Fix a warning * Extend codegen test a bit * Find and replace UnresolvedSYCLKernelNameExpr -> UnresolvedSYCLKernelLaunchExpr * Implement the thing * One more find and replace * I don't know how it looks like * Find and replace again * Switch to UnresolvedSYCLKernelEntryPointStmt * Apply suggestions from code review * Remove log.txt * Implement visiting * Add tests * Apply suggestions from code review Co-authored-by: Tom Honermann <[email protected]> * IdExpr -> KernelLaunchIdExpr * Don't rely on compound * UnresolvedSYCLKernelEntryPointStmt -> UnresolvedSYCLKernelCall * Fix warnings * Rename sycl_enqueue_kernel_launch -> sycl_kernel_launch * Apply suggestions from code review Co-authored-by: Tom Honermann <[email protected]> * Remove array decay * Add windows run line to the sema test --------- Co-authored-by: Tom Honermann <[email protected]>
1 parent 70f34c3 commit dd8f6cf

33 files changed

+780
-126
lines changed

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2999,6 +2999,13 @@ DEF_TRAVERSE_STMT(ParenListExpr, {})
29992999
DEF_TRAVERSE_STMT(SYCLUniqueStableNameExpr, {
30003000
TRY_TO(TraverseTypeLoc(S->getTypeSourceInfo()->getTypeLoc()));
30013001
})
3002+
DEF_TRAVERSE_STMT(UnresolvedSYCLKernelCallStmt, {
3003+
if (getDerived().shouldVisitImplicitCode()) {
3004+
TRY_TO(TraverseStmt(S->getOriginalStmt()));
3005+
TRY_TO(TraverseStmt(S->getKernelLaunchIdExpr()));
3006+
ShouldVisitChildren = false;
3007+
}
3008+
})
30023009
DEF_TRAVERSE_STMT(OpenACCAsteriskSizeExpr, {})
30033010
DEF_TRAVERSE_STMT(PredefinedExpr, {})
30043011
DEF_TRAVERSE_STMT(ShuffleVectorExpr, {})

clang/include/clang/AST/StmtSYCL.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,59 @@ class SYCLKernelCallStmt : public Stmt {
9999
}
100100
};
101101

102+
// UnresolvedSYCLKernelCallStmt represents an invocation of a SYCL kernel in
103+
// a dependent context for which lookup of the sycl_enqueue_kernel_launch
104+
// identifier cannot be performed. These statements are transformed to
105+
// SYCLKernelCallStmt during template instantiation.
106+
class UnresolvedSYCLKernelCallStmt : public Stmt {
107+
friend class ASTStmtReader;
108+
Stmt *OriginalStmt = nullptr;
109+
// KernelLaunchIdExpr stores an UnresolvedLookupExpr or UnresolvedMemberExpr
110+
// corresponding to the SYCL kernel launch function for which a call
111+
// will be synthesized during template instantiation.
112+
Expr *KernelLaunchIdExpr = nullptr;
113+
UnresolvedSYCLKernelCallStmt(CompoundStmt *CS, Expr *IdExpr)
114+
: Stmt(UnresolvedSYCLKernelCallStmtClass), OriginalStmt(CS),
115+
KernelLaunchIdExpr(IdExpr) {}
116+
117+
void setKernelLaunchIdExpr(Expr *IdExpr) { KernelLaunchIdExpr = IdExpr; }
118+
void setOriginalStmt(CompoundStmt *CS) { OriginalStmt = CS; }
119+
120+
public:
121+
static UnresolvedSYCLKernelCallStmt *
122+
Create(const ASTContext &C, CompoundStmt *CS, Expr *IdExpr) {
123+
return new (C) UnresolvedSYCLKernelCallStmt(CS, IdExpr);
124+
}
125+
126+
static UnresolvedSYCLKernelCallStmt *CreateEmpty(const ASTContext &C) {
127+
return new (C) UnresolvedSYCLKernelCallStmt(nullptr, nullptr);
128+
}
129+
130+
Expr *getKernelLaunchIdExpr() const { return KernelLaunchIdExpr; }
131+
CompoundStmt *getOriginalStmt() { return cast<CompoundStmt>(OriginalStmt); }
132+
const CompoundStmt *getOriginalStmt() const {
133+
return cast<CompoundStmt>(OriginalStmt);
134+
}
135+
136+
SourceLocation getBeginLoc() const LLVM_READONLY {
137+
return getOriginalStmt()->getBeginLoc();
138+
}
139+
140+
SourceLocation getEndLoc() const LLVM_READONLY {
141+
return getOriginalStmt()->getEndLoc();
142+
}
143+
static bool classof(const Stmt *T) {
144+
return T->getStmtClass() == UnresolvedSYCLKernelCallStmtClass;
145+
}
146+
child_range children() {
147+
return child_range(&OriginalStmt, &OriginalStmt + 1);
148+
}
149+
150+
const_child_range children() const {
151+
return const_child_range(&OriginalStmt, &OriginalStmt + 1);
152+
}
153+
};
154+
102155
} // end namespace clang
103156

104157
#endif // LLVM_CLANG_AST_STMTSYCL_H

clang/include/clang/Basic/AttrDocs.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ follows.
554554
namespace sycl {
555555
class handler {
556556
template<typename KernelNameType, typename... Ts>
557-
void sycl_enqueue_kernel_launch(const char *KernelName, Ts...) {
557+
void sycl_kernel_launch(const char *KernelName, Ts...) {
558558
// Call functions appropriate for the desired offload backend
559559
// (OpenCL, CUDA, HIP, Level Zero, etc...) to enqueue kernel invocation.
560560
}
@@ -622,7 +622,7 @@ The offload kernel entry point for a SYCL kernel performs the following tasks:
622622
The ``sycl_kernel_entry_point`` attribute facilitates or automates these tasks
623623
by generating the offload kernel entry point, generating a unique symbol name
624624
for it, synthesizing code for kernel argument decomposition and reconstruction,
625-
and synthesizing a call to a ``sycl_enqueue_kernel_launch`` function template
625+
and synthesizing a call to a ``sycl_kernel_launch`` function template
626626
with the kernel name type, kernel symbol name, and (decomposed) kernel arguments
627627
passed as template or function arguments.
628628

@@ -690,7 +690,7 @@ replaced with synthesized code that looks approximately as follows.
690690

691691
sycl::stream sout = Kernel.sout;
692692
S s = Kernel.s;
693-
sycl_enqueue_kernel_launch<KN>("kernel-symbol-name", sout, s);
693+
sycl_kernel_launch<KN>("kernel-symbol-name", sout, s);
694694

695695
There are a few items worthy of note:
696696

@@ -701,16 +701,16 @@ There are a few items worthy of note:
701701
#. ``kernel-symbol-name`` is substituted for the actual symbol name that would
702702
be generated; these names are implementation details subject to change.
703703

704-
#. Lookup for the ``sycl_enqueue_kernel_launch()`` function template is
704+
#. Lookup for the ``sycl_kernel_launch()`` function template is
705705
performed from the (possibly instantiated) location of the definition of
706706
``kernel_entry_point()``. If overload resolution fails, the program is
707707
ill-formed. If the selected overload is a non-static member function, then
708708
``this`` is passed for the implicit object parameter.
709709

710-
#. Function arguments passed to ``sycl_enqueue_kernel_launch()`` are passed
710+
#. Function arguments passed to ``sycl_kernel_launch()`` are passed
711711
as if by ``std::forward<X>(x)``.
712712

713-
#. The ``sycl_enqueue_kernel_launch()`` function is expected to be provided by
713+
#. The ``sycl_kernel_launch()`` function is expected to be provided by
714714
the SYCL library implementation. It is responsible for scheduling execution
715715
of the generated offload kernel entry point identified by
716716
``kernel-symbol-name`` and copying the (decomposed) kernel arguments to
@@ -721,7 +721,7 @@ attribute to be called for the offload kernel entry point to be emitted. For
721721
inline functions and function templates, any ODR-use will suffice. For other
722722
functions, an ODR-use is not required; the offload kernel entry point will be
723723
emitted if the function is defined. In any case, a call to the function is
724-
required for the synthesized call to ``sycl_enqueue_kernel_launch()`` to occur.
724+
required for the synthesized call to ``sycl_kernel_launch()`` to occur.
725725

726726
Functions declared with the ``sycl_kernel_entry_point`` attribute are not
727727
limited to the simple example shown above. They may have additional template

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13015,6 +13015,15 @@ def err_sycl_entry_point_return_type : Error<
1301513015
def err_sycl_entry_point_deduced_return_type : Error<
1301613016
"the %0 attribute only applies to functions with a non-deduced 'void' return"
1301713017
" type">;
13018+
def err_sycl_host_no_launch_function : Error<
13019+
"unable to find suitable 'sycl_kernel_launch' function for host code "
13020+
"synthesis">;
13021+
def warn_sycl_device_no_host_launch_function : Warning<
13022+
"unable to find suitable 'sycl_kernel_launch' function for host code "
13023+
"synthesis">,
13024+
InGroup<DiagGroup<"sycl-host-launcher">>;
13025+
def note_sycl_host_launch_function : Note<
13026+
"define 'sycl_kernel_launch' function template to fix this problem">;
1301813027

1301913028
def warn_cuda_maxclusterrank_sm_90 : Warning<
1302013029
"maxclusterrank requires sm_90 or higher, CUDA arch provided: %0, ignoring "

clang/include/clang/Basic/StmtNodes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def CaseStmt : StmtNode<SwitchCase>;
2525
def DefaultStmt : StmtNode<SwitchCase>;
2626
def CapturedStmt : StmtNode<Stmt>;
2727
def SYCLKernelCallStmt : StmtNode<Stmt>;
28+
def UnresolvedSYCLKernelCallStmt : StmtNode<Stmt>;
2829

2930
// Statements that might produce a value (for example, as the last non-null
3031
// statement in a GNU statement-expression).

clang/include/clang/Sema/ScopeInfo.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,10 @@ class FunctionScopeInfo {
245245
/// The set of GNU address of label extension "&&label".
246246
llvm::SmallVector<AddrLabelExpr *, 4> AddrLabels;
247247

248+
/// An unresolved identifier lookup expression for an implicit call
249+
/// to a SYCL kernel launch function in a dependent context.
250+
Expr *SYCLKernelLaunchIdExpr = nullptr;
251+
248252
public:
249253
/// Represents a simple identification of a weak object.
250254
///

clang/include/clang/Sema/SemaSYCL.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,11 @@ class SemaSYCL : public SemaBase {
6666

6767
void CheckSYCLExternalFunctionDecl(FunctionDecl *FD);
6868
void CheckSYCLEntryPointFunctionDecl(FunctionDecl *FD);
69-
StmtResult BuildSYCLKernelCallStmt(FunctionDecl *FD, CompoundStmt *Body);
69+
StmtResult BuildSYCLKernelCallStmt(FunctionDecl *FD, CompoundStmt *Body,
70+
Expr *LaunchIdExpr);
71+
ExprResult BuildSYCLKernelLaunchIdExpr(FunctionDecl *FD, QualType KNT);
72+
StmtResult BuildUnresolvedSYCLKernelCallStmt(CompoundStmt *CS,
73+
Expr *IdExpr);
7074
};
7175

7276
} // namespace clang

clang/include/clang/Serialization/ASTBitCodes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1615,6 +1615,9 @@ enum StmtCode {
16151615
/// A SYCLKernelCallStmt record.
16161616
STMT_SYCLKERNELCALL,
16171617

1618+
/// A SYCLKernelCallStmt record.
1619+
STMT_UNRESOLVED_SYCL_KERNEL_CALL,
1620+
16181621
/// A GCC-style AsmStmt record.
16191622
STMT_GCCASM,
16201623

clang/lib/AST/ComputeDependence.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "clang/AST/ExprConcepts.h"
1717
#include "clang/AST/ExprObjC.h"
1818
#include "clang/AST/ExprOpenMP.h"
19+
#include "clang/AST/StmtSYCL.h"
1920
#include "clang/Basic/ExceptionSpecificationType.h"
2021
#include "llvm/ADT/ArrayRef.h"
2122

clang/lib/AST/StmtPrinter.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,6 +1428,11 @@ void StmtPrinter::VisitSYCLUniqueStableNameExpr(
14281428
OS << ")";
14291429
}
14301430

1431+
void StmtPrinter::VisitUnresolvedSYCLKernelCallStmt(
1432+
UnresolvedSYCLKernelCallStmt *Node) {
1433+
PrintStmt(Node->getOriginalStmt());
1434+
}
1435+
14311436
void StmtPrinter::VisitPredefinedExpr(PredefinedExpr *Node) {
14321437
OS << PredefinedExpr::getIdentKindName(Node->getIdentKind());
14331438
}

0 commit comments

Comments
 (0)