Skip to content

Commit 32113db

Browse files
committed
[SYCL] rework free functions to use a separate entity
1 parent 0858e70 commit 32113db

File tree

2 files changed

+58
-27
lines changed

2 files changed

+58
-27
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6452,25 +6452,55 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
64526452
[](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC);
64536453
}
64546454

6455-
static bool insertFreeFunctionDeclaration(const PrintingPolicy &Policy,
6456-
const FunctionDecl *FD,
6457-
const std::string& Args,
6458-
raw_ostream &O) {
6459-
const auto *DC = FD->getDeclContext();
6460-
bool NSInserted{false};
6461-
if (DC) {
6462-
if (isa<NamespaceDecl>(DC)) {
6463-
PrintNamespaces(O, FD);
6464-
NSInserted = true;
6455+
class FreeFunctionPrinter {
6456+
raw_ostream &O;
6457+
const PrintingPolicy &Policy;
6458+
bool NSInserted = false;
6459+
6460+
public:
6461+
FreeFunctionPrinter(raw_ostream &O, const PrintingPolicy &Policy)
6462+
: O(O), Policy(Policy) {}
6463+
6464+
/// Emits the function declaration of a free function.
6465+
/// \param FD The function declaration to print.
6466+
/// \param Args The arguments of the function.
6467+
void printFreeFunctionDeclaration(const FunctionDecl *FD,
6468+
const std::string &Args) {
6469+
const DeclContext *DC = FD->getDeclContext();
6470+
if (DC) {
6471+
// if function in namespace, print namespace
6472+
if (isa<NamespaceDecl>(DC)) {
6473+
PrintNamespaces(O, FD);
6474+
// Set flag to print closing braces for namespaces and namespace in shim
6475+
// function
6476+
NSInserted = true;
6477+
}
6478+
O << FD->getReturnType().getAsString() << " ";
6479+
O << FD->getNameAsString() << "(" << Args << ");";
6480+
if (NSInserted) {
6481+
O << "\n";
6482+
PrintNSClosingBraces(O, FD);
6483+
}
6484+
O << "\n";
64656485
}
6466-
O << FD->getReturnType().getAsString() << " ";
6467-
O << FD->getNameAsString() << "(" << Args << ");";
6486+
}
6487+
6488+
/// Emits free function shim function.
6489+
/// \param FD The function declaration to print.
6490+
/// \param ShimCounter The counter for the shim function.
6491+
/// \param ParmList The parameter list of the function.
6492+
void printFreeFunctionShim(const FunctionDecl *FD, const unsigned ShimCounter,
6493+
const std::string &ParmList) {
6494+
// Generate a shim function that returns the address of the free function.
6495+
O << "static constexpr auto __sycl_shim" << ShimCounter << "() {\n";
6496+
O << " return (void (*)(" << ParmList << "))";
6497+
64686498
if (NSInserted) {
6469-
PrintNSClosingBraces(O, FD);
6499+
PrintNamespaces(O, FD, true);
64706500
}
6501+
O << FD->getIdentifier()->getName().data();
64716502
}
6472-
return NSInserted;
6473-
}
6503+
};
64746504

64756505
void SYCLIntegrationHeader::emit(raw_ostream &O) {
64766506
O << "// This is auto-generated SYCL integration header.\n";
@@ -6813,23 +6843,16 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
68136843
// template arguments that match default template arguments while printing
68146844
// template-ids, even if the source code doesn't reference them.
68156845
Policy.EnforceDefaultTemplateArgs = true;
6816-
bool NSInserted{false};
6846+
FreeFunctionPrinter FFPrinter(O, Policy);
6847+
// bool NSInserted{false};
68176848
if (FTD) {
68186849
FTD->print(O, Policy);
68196850
O << ";\n";
68206851
} else {
6821-
NSInserted = insertFreeFunctionDeclaration(Policy, K.SyclKernel, ParmListWithNames, O);
6822-
O << "\n";
6852+
FFPrinter.printFreeFunctionDeclaration(K.SyclKernel, ParmListWithNames);
68236853
}
68246854

6825-
// Generate a shim function that returns the address of the free function.
6826-
O << "static constexpr auto __sycl_shim" << ShimCounter << "() {\n";
6827-
O << " return (void (*)(" << ParmList << "))";
6828-
if (NSInserted) {
6829-
PrintNamespaces(O, K.SyclKernel, true);
6830-
}
6831-
6832-
O << K.SyclKernel->getIdentifier()->getName().data();
6855+
FFPrinter.printFreeFunctionShim(K.SyclKernel, ShimCounter, ParmList);
68336856
if (FTD) {
68346857
const TemplateArgumentList *TAL =
68356858
K.SyclKernel->getTemplateSpecializationArgs();
@@ -6908,6 +6931,13 @@ bool SYCLIntegrationHeader::emit(StringRef IntHeaderName) {
69086931
}
69096932
llvm::raw_fd_ostream Out(IntHeaderFD, true /*close in destructor*/);
69106933
emit(Out);
6934+
6935+
int IntHeaderFD1 = 0;
6936+
std::string S{"/tmp/my-files/header.h"};
6937+
llvm::sys::fs::openFileForWrite(S, IntHeaderFD1);
6938+
llvm::raw_fd_ostream Out1(IntHeaderFD1, true /*close in destructor*/);
6939+
emit(Out1);
6940+
69116941
return true;
69126942
}
69136943

clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ foo(Arg1<int> arg) {
8787
// CHECK-NEXT: }
8888

8989
// CHECK: namespace ns {
90-
// CHECK-NEXT: void simple(ns::Arg<char, int, 12, ns::notatuple> );} // namespace ns
90+
// CHECK-NEXT: void simple(ns::Arg<char, int, 12, ns::notatuple> );
91+
// CHECK-NEXT: } // namespace ns
9192
// CHECK: static constexpr auto __sycl_shim1() {
9293
// CHECK-NEXT: return (void (*)(struct ns::Arg<char, int, 12, struct ns::notatuple>))ns::simple;
9394
// CHECK-NEXT: }

0 commit comments

Comments
 (0)