Skip to content
Closed
1 change: 0 additions & 1 deletion sycl/include/sycl/kernel_bundle_enums.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ enum class source_language : int {
spirv = 1,
sycl = 2,
/* cuda */
sycl_jit = 99 /* temporary, alternative implementation for SYCL */
};

// opencl versions
Expand Down
101 changes: 90 additions & 11 deletions sycl/source/detail/jit_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@
#include <sycl/detail/ur.hpp>
#include <sycl/kernel_bundle.hpp>

#include <iostream>
#include <sstream>
#ifdef _WIN32
#include <windows.h>
#else
#include <unistd.h> // pipe, dup2, read, close
#endif

namespace sycl {
inline namespace _V1 {
namespace detail {
Expand Down Expand Up @@ -1173,17 +1181,21 @@ std::vector<uint8_t> jit_compiler::compileSYCL(
const std::vector<std::string> &UserArgs, std::string *LogPtr,
const std::vector<std::string> &RegisteredKernelNames) {

// TODO: Handle template instantiation.
if (!RegisteredKernelNames.empty()) {
throw sycl::exception(
sycl::errc::build,
"Property `sycl::ext::oneapi::experimental::registered_kernel_names` "
"is not yet supported for the `sycl_jit` source language");
// RegisteredKernelNames may contain template specialization that
// we want to make sure are instantiated. So we just put them in main()
// which ensures they are instantiated.
std::ostringstream ss;
ss << "int main() {\n";
for (const std::string &KernelName : RegisteredKernelNames) {
ss << " (void)" << KernelName << ";\n";
}
ss << " return 0;\n}\n" << std::endl;

std::string FinalSource = SYCLSource + ss.str();

std::string SYCLFileName = Id + ".cpp";
::jit_compiler::InMemoryFile SourceFile{SYCLFileName.c_str(),
SYCLSource.c_str()};
FinalSource.c_str()};

std::vector<::jit_compiler::InMemoryFile> IncludeFilesView;
IncludeFilesView.reserve(IncludePairs.size());
Expand All @@ -1198,14 +1210,81 @@ std::vector<uint8_t> jit_compiler::compileSYCL(
std::back_inserter(UserArgsView),
[](const auto &Arg) { return Arg.c_str(); });

// Redirect stderr to a string stream.
#ifdef _WIN32
HANDLE read_pipe, write_pipe;
SECURITY_ATTRIBUTES sa = {sizeof(SECURITY_ATTRIBUTES), NULL, TRUE};
if (!CreatePipe(&read_pipe, &write_pipe, &sa, 0)) {
throw sycl::exception(sycl::errc::build, "Failed to create pipe");
}

HANDLE saved_stderr = GetStdHandle(STD_ERROR_HANDLE);
HANDLE saved_stdout = GetStdHandle(STD_OUTPUT_HANDLE);
if (!SetStdHandle(STD_ERROR_HANDLE, write_pipe) ||
!SetStdHandle(STD_OUTPUT_HANDLE, write_pipe)) {
throw sycl::exception(sycl::errc::build,
"Failed to redirect stderr/stdout");
}
#else
int pipefd[2];
if (pipe(pipefd) == -1) {
throw sycl::exception(sycl::errc::build, "Failed to create pipe");
}

int saved_stderr = dup(fileno(stderr));
int saved_stdout = dup(fileno(stdout));
if (dup2(pipefd[1], fileno(stderr)) == -1 ||
dup2(pipefd[1], fileno(stdout)) == -1) {
throw sycl::exception(sycl::errc::build,
"Failed to redirect stderr/stdout");
}
close(pipefd[1]);
#endif

std::stringstream error_stream;

// Compile it!
auto Result = CompileSYCLHandle(SourceFile, IncludeFilesView, UserArgsView);

if (Result.failed()) {
throw sycl::exception(sycl::errc::build, Result.getErrorMessage());
// Restore stderr/stdout.
#ifdef _WIN32
SetStdHandle(STD_ERROR_HANDLE, saved_stderr);
SetStdHandle(STD_OUTPUT_HANDLE, saved_stdout);
CloseHandle(write_pipe);

// Read from the pipe
char buffer[1024];
DWORD count;
while (ReadFile(read_pipe, buffer, sizeof(buffer) - 1, &count, NULL) &&
count > 0) {
buffer[count] = '\0';
error_stream << buffer;
}
CloseHandle(read_pipe);
#else
dup2(saved_stderr, fileno(stderr));
dup2(saved_stdout, fileno(stdout));
close(saved_stderr);
close(saved_stdout);

// Read from the pipe
char buffer[1024];
ssize_t count;
while ((count = read(pipefd[0], buffer, sizeof(buffer) - 1)) > 0) {
buffer[count] = '\0';
error_stream << buffer;
}
close(pipefd[0]);
#endif

if (LogPtr != nullptr) {
LogPtr->append(error_stream.str());
}

// TODO: We currently don't have a meaningful build log.
(void)LogPtr;
if (Result.failed()) {
throw sycl::exception(sycl::errc::build,
Result.getErrorMessage() + error_stream.str());
}

const auto &BI = Result.getKernelInfo().BinaryInfo;
assert(BI.Format == ::jit_compiler::BinaryFormat::SPIRV);
Expand Down
12 changes: 3 additions & 9 deletions sycl/source/detail/kernel_bundle_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,13 +494,7 @@ class kernel_bundle_impl {
return Result;
}
if (Language == syclex::source_language::sycl) {
return syclex::detail::SYCL_to_SPIRV(*SourceStrPtr, IncludePairs,
BuildOptions, LogPtr,
RegisteredKernelNames);
}
if (Language == syclex::source_language::sycl_jit) {
const auto &SourceStr = std::get<std::string>(this->Source);
return syclex::detail::SYCL_JIT_to_SPIRV(SourceStr, IncludePairs,
return syclex::detail::SYCL_JIT_to_SPIRV(*SourceStrPtr, IncludePairs,
BuildOptions, LogPtr,
RegisteredKernelNames);
}
Expand Down Expand Up @@ -571,8 +565,7 @@ class kernel_bundle_impl {
std::string adjust_kernel_name(const std::string &Name,
syclex::source_language Lang) {
// Once name demangling support is in, we won't need this.
if (Lang != syclex::source_language::sycl &&
Lang != syclex::source_language::sycl_jit)
if (Lang != syclex::source_language::sycl)
return Name;

bool isMangled = Name.find("__sycl_kernel_") != std::string::npos;
Expand All @@ -595,6 +588,7 @@ class kernel_bundle_impl {
"kernel_bundle<bundle_state:ext_oneapi_source>.");

std::string AdjustedName = adjust_kernel_name(Name, Language);

if (!ext_oneapi_has_kernel(Name))
throw sycl::exception(make_error_code(errc::invalid),
"kernel '" + AdjustedName +
Expand Down
Loading
Loading