diff --git a/clang/docs/ClangOffloadPackager.rst b/clang/docs/ClangOffloadPackager.rst index 2b985e260e302..8c48b6396af8f 100644 --- a/clang/docs/ClangOffloadPackager.rst +++ b/clang/docs/ClangOffloadPackager.rst @@ -97,7 +97,8 @@ the following values for the :ref:`offload kind` and the +---------------+-------+---------------------------------------+ | IMG_PTX | 0x05 | The image is a CUDA PTX file | +---------------+-------+---------------------------------------+ - + | IMG_SPV | 0x06 | The image is a SPIR-V binary file | + +---------------+-------+---------------------------------------+ .. table:: Offload Kind :name: table-offload_kind @@ -112,6 +113,8 @@ the following values for the :ref:`offload kind` and the +------------+-------+---------------------------------------+ | OFK_HIP | 0x03 | The producer was HIP | +------------+-------+---------------------------------------+ + | OFK_SYCL | 0x04 | The producer was SYCL | + +------------+-------+---------------------------------------+ The flags are used to signify certain conditions, such as the presence of debugging information or whether or not LTO was used. 
The string entry table is diff --git a/clang/tools/clang-sycl-linker/ClangSYCLLinker.cpp b/clang/tools/clang-sycl-linker/ClangSYCLLinker.cpp index ab718c5a87c40..28f89cea3d44f 100644 --- a/clang/tools/clang-sycl-linker/ClangSYCLLinker.cpp +++ b/clang/tools/clang-sycl-linker/ClangSYCLLinker.cpp @@ -22,6 +22,8 @@ #include "llvm/CodeGen/CommandFlags.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" #include "llvm/IRReader/IRReader.h" #include "llvm/LTO/LTO.h" #include "llvm/Linker/Linker.h" @@ -50,6 +52,9 @@ #include "llvm/Support/TimeProfiler.h" #include "llvm/Support/WithColor.h" #include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/IPO/GlobalDCE.h" +#include "llvm/Transforms/Utils/SYCLSplitModule.h" +#include "llvm/Transforms/Utils/SYCLUtils.h" using namespace llvm; using namespace llvm::opt; @@ -77,9 +82,14 @@ static void printVersion(raw_ostream &OS) { /// The value of `argv[0]` when run. static const char *Executable; +/// Mutex lock to protect writes to shared TempFiles in parallel. +static std::mutex TempFilesMutex; + /// Temporary files to be cleaned up. static SmallVector> TempFiles; +using OffloadingImage = OffloadBinary::OffloadingImage; + namespace { // Must not overlap with llvm::opt::DriverFlag. enum LinkerFlags { LinkerOnlyOption = (1 << 4) }; @@ -143,6 +153,59 @@ Expected createTempFile(const ArgList &Args, const Twine &Prefix, return TempFiles.back(); } +/// Get a temporary filename suitable for output. +Expected createOutputFile(const Twine &Prefix, StringRef Extension) { + std::scoped_lock Lock(TempFilesMutex); + SmallString<128> OutputFile; + if (SaveTemps) { + (Prefix + "." 
+ Extension).toNullTerminatedStringRef(OutputFile); + } else { + if (std::error_code EC = + sys::fs::createTemporaryFile(Prefix, Extension, OutputFile)) + return createFileError(OutputFile, EC); + } + + TempFiles.emplace_back(std::move(OutputFile)); + return TempFiles.back(); +} + +Expected writeOffloadFile(const OffloadFile &File) { + const OffloadBinary &Binary = *File.getBinary(); + + StringRef Prefix = + sys::path::stem(Binary.getMemoryBufferRef().getBufferIdentifier()); + SmallString<128> Filename; + (Prefix + "-" + Binary.getTriple() + "-" + Binary.getArch()) + .toVector(Filename); + llvm::replace(Filename, ':', '-'); + auto TempFileOrErr = createOutputFile(Filename, "o"); + if (!TempFileOrErr) + return TempFileOrErr.takeError(); + + Expected> OutputOrErr = + FileOutputBuffer::create(*TempFileOrErr, Binary.getImage().size()); + if (!OutputOrErr) + return OutputOrErr.takeError(); + std::unique_ptr Output = std::move(*OutputOrErr); + llvm::copy(Binary.getImage(), Output->getBufferStart()); + if (Error E = Output->commit()) + return std::move(E); + + return *TempFileOrErr; +} + +static Error writeFile(StringRef Filename, StringRef Data) { + Expected> OutputOrErr = + FileOutputBuffer::create(Filename, Data.size()); + if (!OutputOrErr) + return OutputOrErr.takeError(); + std::unique_ptr Output = std::move(*OutputOrErr); + llvm::copy(Data, Output->getBufferStart()); + if (Error E = Output->commit()) + return E; + return Error::success(); +} + Expected> getInput(const ArgList &Args) { // Collect all input bitcode files to be passed to the device linking stage. SmallVector BitcodeFiles; @@ -274,12 +337,79 @@ Expected linkDeviceCode(ArrayRef InputFiles, return *BitcodeOutput; } +void cleanupModule(Module &M) { + ModuleAnalysisManager MAM; + MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); + ModulePassManager MPM; + MPM.addPass(GlobalDCEPass()); // Delete unreachable globals. 
+ MPM.run(M, MAM); +} + +void writeModuleToFile(const Module &M, StringRef Path, bool OutputAssembly) { + int FD = -1; + if (std::error_code EC = sys::fs::openFileForWrite(Path, FD)) { + errs() << formatv("error opening file: {0}, error: {1}", Path, EC.message()) + << '\n'; + exit(1); + } + + raw_fd_ostream OS(FD, /*ShouldClose*/ true); + if (OutputAssembly) + M.print(OS, /*AssemblyAnnotationWriter*/ nullptr); + else + WriteBitcodeToFile(M, OS); +} + +Expected> runSYCLSplitModule(std::unique_ptr M, const ArgList &Args) { + SmallVector SplitModules; + if (Error Err = M->materializeAll()) + return std::move(Err); + auto PostSYCLSplitCallback = [&](std::unique_ptr MPart, + std::string Symbols) { + if (verifyModule(*MPart)) { + errs() << "Broken Module!\n"; + exit(1); + } + if (Error Err = MPart->materializeAll()) { + errs() << "Broken Module!\n"; + exit(1); + } + // TODO: DCE is a crucial pass in a SYCL post-link pipeline. + // At the moment, LIT checking can't be performed without DCE. + cleanupModule(*MPart); + size_t ID = SplitModules.size(); + StringRef ModuleSuffix = ".bc"; + std::string ModulePath = + (Twine(OutputFile) + "_post_link_" + Twine(ID) + ModuleSuffix).str(); + writeModuleToFile(*MPart, ModulePath, /* OutputAssembly */ false); + SplitModules.emplace_back(std::move(ModulePath), std::move(Symbols)); + }; + + StringRef Mode = Args.getLastArgValue(OPT_sycl_split_mode_EQ); + auto SYCLSplitMode = StringSwitch(Mode) + .Case("per_source", IRSplitMode::IRSM_PER_TU) + .Case("per_kernel", IRSplitMode::IRSM_PER_KERNEL) + .Case("none", IRSplitMode::IRSM_NONE) + .Default(IRSplitMode::IRSM_NONE); + SYCLSplitModule(std::move(M), SYCLSplitMode, PostSYCLSplitCallback); + + if (Verbose) { + std::string OutputFiles; + for (size_t I = 0, E = SplitModules.size(); I != E; ++I) { + OutputFiles.append(SplitModules[I].ModuleFilePath); + OutputFiles.append("\n"); + } + errs() << formatv("sycl-module-split: outputs:\n{0}\n", OutputFiles); + } + return SplitModules; +} + +/// Run
LLVM to SPIR-V translation. /// Converts 'File' from LLVM bitcode to SPIR-V format using SPIR-V backend. /// 'Args' encompasses all arguments required for linking device code and will /// be parsed to generate options required to be passed into the backend. -static Expected runSPIRVCodeGen(StringRef File, const ArgList &Args, - LLVMContext &C) { +static Error runSPIRVCodeGen(StringRef File, const ArgList &Args, + StringRef SPVFile, LLVMContext &C) { llvm::TimeTraceScope TimeScope("SPIR-V code generation"); // Parse input module. @@ -288,6 +418,9 @@ static Expected runSPIRVCodeGen(StringRef File, const ArgList &Args, if (!M) return createStringError(Err.getMessage()); + if (Error Err = M->materializeAll()) + return std::move(Err); + Triple TargetTriple(Args.getLastArgValue(OPT_triple_EQ)); M->setTargetTriple(TargetTriple); @@ -313,7 +446,7 @@ static Expected runSPIRVCodeGen(StringRef File, const ArgList &Args, // Open output file for writing. int FD = -1; - if (std::error_code EC = sys::fs::openFileForWrite(OutputFile, FD)) + if (std::error_code EC = sys::fs::openFileForWrite(SPVFile, FD)) return errorCodeToError(EC); auto OS = std::make_unique(FD, true); @@ -328,9 +461,9 @@ static Expected runSPIRVCodeGen(StringRef File, const ArgList &Args, if (Verbose) errs() << formatv("SPIR-V Backend: input: {0}, output: {1}\n", File, - OutputFile); + SPVFile); - return OutputFile; + return Error::success(); } /// Performs the following steps: @@ -339,6 +472,7 @@ static Expected runSPIRVCodeGen(StringRef File, const ArgList &Args, Error runSYCLLink(ArrayRef Files, const ArgList &Args) { llvm::TimeTraceScope TimeScope("SYCL device linking"); + std::mutex ImageMtx; LLVMContext C; // Link all input bitcode files and SYCL device library files, if any. 
@@ -346,10 +480,77 @@ Error runSYCLLink(ArrayRef Files, const ArgList &Args) { if (!LinkedFile) reportError(LinkedFile.takeError()); + auto LinkedModule = getBitcodeModule(*LinkedFile, C); + if (!LinkedModule) + return LinkedModule.takeError(); + // sycl-post-link step + auto SplitModules = runSYCLSplitModule(std::move(*LinkedModule), Args); + if (!SplitModules) + reportError(SplitModules.takeError()); + // SPIR-V code generation step. - auto SPVFile = runSPIRVCodeGen(*LinkedFile, Args, C); - if (!SPVFile) - return SPVFile.takeError(); + for (size_t I = 0, E = (*SplitModules).size(); I != E; ++I) { + std::string SPVFile(OutputFile); + SPVFile.append(utostr(I)); + auto Err = runSPIRVCodeGen((*SplitModules)[I].ModuleFilePath, Args, SPVFile, C); + if (Err) + return std::move(Err); + (*SplitModules)[I].ModuleFilePath = SPVFile; + } + + SmallVector BinaryData; + raw_svector_ostream OS(BinaryData); + for (size_t I = 0, E = (*SplitModules).size(); I != E; ++I) { + auto File = (*SplitModules)[I].ModuleFilePath; + llvm::ErrorOr> FileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(File); + if (std::error_code EC = FileOrErr.getError()) { + if (DryRun) + FileOrErr = MemoryBuffer::getMemBuffer(""); + else + return createFileError(File, EC); + } + + std::scoped_lock Guard(ImageMtx); + OffloadingImage TheImage{}; + TheImage.TheImageKind = IMG_Object; + TheImage.TheOffloadKind = OFK_SYCL; + TheImage.StringData["triple"] = + Args.MakeArgString(Args.getLastArgValue(OPT_triple_EQ)); + TheImage.StringData["arch"] = + Args.MakeArgString(Args.getLastArgValue(OPT_arch_EQ)); + TheImage.Image = std::move(*FileOrErr); + + llvm::SmallString<0> Buffer = OffloadBinary::write(TheImage); + if (Buffer.size() % OffloadBinary::getAlignment() != 0) + return createStringError(inconvertibleErrorCode(), + "Offload binary has invalid size alignment"); + OS << Buffer; + } + if (Error E = writeFile(OutputFile, + StringRef(BinaryData.begin(), BinaryData.size()))) + return E; + + { + ErrorOr> BufferOrErr = + 
MemoryBuffer::getFileOrSTDIN(OutputFile); + if (std::error_code EC = BufferOrErr.getError()) + return createFileError(OutputFile, EC); + + MemoryBufferRef Buffer = **BufferOrErr; + SmallVector Binaries; + if (Error Err = extractOffloadBinaries(Buffer, Binaries)) + return std::move(Err); + + unsigned I = 1; + for (auto &OffloadFile : Binaries) { + auto FileNameOrErr = writeOffloadFile(OffloadFile); + if (!FileNameOrErr) + return FileNameOrErr.takeError(); + llvm::errs() << I++ << ". " << *FileNameOrErr << "\n"; + } + } + return Error::success(); } diff --git a/clang/tools/clang-sycl-linker/SYCLLinkOpts.td b/clang/tools/clang-sycl-linker/SYCLLinkOpts.td index 1006784973b87..283c97df9eb0e 100644 --- a/clang/tools/clang-sycl-linker/SYCLLinkOpts.td +++ b/clang/tools/clang-sycl-linker/SYCLLinkOpts.td @@ -39,6 +39,9 @@ def save_temps : Flag<["--", "-"], "save-temps">, def dry_run : Flag<["--", "-"], "dry-run">, Flags<[LinkerOnlyOption]>, HelpText<"Print generated commands without running.">; +def sycl_split_mode_EQ : Joined<["--", "-"], "sycl-split-mode=">, + HelpText<"Mode of splitting performed by SYCL splitting algorithm. Options are 'per_source', 'per_kernel' and 'none'.">; + def spirv_dump_device_code_EQ : Joined<["--", "-"], "spirv-dump-device-code=">, Flags<[LinkerOnlyOption]>, HelpText<"Path to the folder where the tool dumps SPIR-V device code. 
Other formats aren't dumped.">; diff --git a/llvm/include/llvm/Object/OffloadBinary.h b/llvm/include/llvm/Object/OffloadBinary.h index c02aec8d956ed..b9344944eae3c 100644 --- a/llvm/include/llvm/Object/OffloadBinary.h +++ b/llvm/include/llvm/Object/OffloadBinary.h @@ -35,6 +35,7 @@ enum OffloadKind : uint16_t { OFK_OpenMP, OFK_Cuda, OFK_HIP, + OFK_SYCL, OFK_LAST, }; @@ -46,6 +47,7 @@ enum ImageKind : uint16_t { IMG_Cubin, IMG_Fatbinary, IMG_PTX, + IMG_SPV, IMG_LAST, }; diff --git a/llvm/include/llvm/Transforms/Utils/SYCLSplitModule.h b/llvm/include/llvm/Transforms/Utils/SYCLSplitModule.h new file mode 100644 index 0000000000000..a3425d19b9c4b --- /dev/null +++ b/llvm/include/llvm/Transforms/Utils/SYCLSplitModule.h @@ -0,0 +1,64 @@ +//===-------- SYCLSplitModule.h - module split ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Functionality to split a module into callgraphs. A callgraph here is a set +// of entry points with all functions reachable from them via a call. The result +// of the split is new modules containing corresponding callgraph. +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_UTILS_SYCLSPLITMODULE_H +#define LLVM_TRANSFORMS_UTILS_SYCLSPLITMODULE_H + +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/StringRef.h" + +#include +#include +#include + +namespace llvm { + +class Module; + +enum class IRSplitMode { + IRSM_PER_TU, // one module per translation unit + IRSM_PER_KERNEL, // one module per kernel + IRSM_NONE // no splitting +}; + +/// \returns IRSplitMode value if \p S is recognized. Otherwise, std::nullopt is +/// returned. 
+std::optional convertStringToSplitMode(StringRef S); + +/// The structure represents a split LLVM Module accompanied by additional +/// information. Split Modules are stored on disk due to the high RAM +/// consumption during the whole splitting process. +struct ModuleAndSYCLMetadata { + std::string ModuleFilePath; + std::string Symbols; + + ModuleAndSYCLMetadata() = default; + ModuleAndSYCLMetadata(const ModuleAndSYCLMetadata &) = default; + ModuleAndSYCLMetadata &operator=(const ModuleAndSYCLMetadata &) = default; + ModuleAndSYCLMetadata(ModuleAndSYCLMetadata &&) = default; + ModuleAndSYCLMetadata &operator=(ModuleAndSYCLMetadata &&) = default; + + ModuleAndSYCLMetadata(std::string_view File, std::string Symbols) + : ModuleFilePath(File), Symbols(std::move(Symbols)) {} +}; + +using PostSYCLSplitCallbackType = + function_ref Part, std::string Symbols)>; + +/// Splits the given module \p M according to the given \p Mode. +/// Every split image is passed to \p Callback. +void SYCLSplitModule(std::unique_ptr M, IRSplitMode Mode, + PostSYCLSplitCallbackType Callback); + +} // namespace llvm + +#endif // LLVM_TRANSFORMS_UTILS_SYCLSPLITMODULE_H diff --git a/llvm/include/llvm/Transforms/Utils/SYCLUtils.h b/llvm/include/llvm/Transforms/Utils/SYCLUtils.h new file mode 100644 index 0000000000000..75459eed6ac0f --- /dev/null +++ b/llvm/include/llvm/Transforms/Utils/SYCLUtils.h @@ -0,0 +1,26 @@ +//===------------ SYCLUtils.h - SYCL utility functions --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Utility functions for SYCL.
+//===----------------------------------------------------------------------===// +#ifndef LLVM_TRANSFORMS_UTILS_SYCLUTILS_H +#define LLVM_TRANSFORMS_UTILS_SYCLUTILS_H + +#include +#include + +namespace llvm { + +class raw_ostream; + +using SYCLStringTable = SmallVector>>; + +void writeSYCLStringTable(const SYCLStringTable &Table, raw_ostream &OS); + +} // namespace llvm + +#endif // LLVM_TRANSFORMS_UTILS_SYCLUTILS_H diff --git a/llvm/lib/Transforms/Utils/CMakeLists.txt b/llvm/lib/Transforms/Utils/CMakeLists.txt index 78cad0d253be8..0ba46bdadea8d 100644 --- a/llvm/lib/Transforms/Utils/CMakeLists.txt +++ b/llvm/lib/Transforms/Utils/CMakeLists.txt @@ -83,6 +83,8 @@ add_llvm_component_library(LLVMTransformUtils SizeOpts.cpp SplitModule.cpp StripNonLineTableDebugInfo.cpp + SYCLSplitModule.cpp + SYCLUtils.cpp SymbolRewriter.cpp UnifyFunctionExitNodes.cpp UnifyLoopExits.cpp diff --git a/llvm/lib/Transforms/Utils/SYCLSplitModule.cpp b/llvm/lib/Transforms/Utils/SYCLSplitModule.cpp new file mode 100644 index 0000000000000..18eca4237c8ae --- /dev/null +++ b/llvm/lib/Transforms/Utils/SYCLSplitModule.cpp @@ -0,0 +1,401 @@ +//===-------- SYCLSplitModule.cpp - Split a module into call graphs -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// See comments in the header. 
+//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/SYCLSplitModule.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/SYCLUtils.h" + +#include +#include + +using namespace llvm; + +#define DEBUG_TYPE "sycl-split-module" + +static bool isKernel(const Function &F) { + return F.getCallingConv() == CallingConv::SPIR_KERNEL || + F.getCallingConv() == CallingConv::AMDGPU_KERNEL; +} + +static bool isEntryPoint(const Function &F) { + // Skip declarations, if any: they should not be included in a vector of + // entry point groups, or otherwise we will end up with incorrectly generated + // list of symbols. + if (F.isDeclaration()) + return false; + + // Kernels are always considered to be entry points + return isKernel(F); +} + +namespace { + +// A vector that contains all entry point functions in a split module. +using EntryPointSet = SetVector; + +/// Represents a named group of entry points.
+struct EntryPointGroup { + std::string GroupName; + EntryPointSet Functions; + + EntryPointGroup() = default; + EntryPointGroup(const EntryPointGroup &) = default; + EntryPointGroup &operator=(const EntryPointGroup &) = default; + EntryPointGroup(EntryPointGroup &&) = default; + EntryPointGroup &operator=(EntryPointGroup &&) = default; + + EntryPointGroup(StringRef GroupName, + EntryPointSet Functions = EntryPointSet()) + : GroupName(GroupName), Functions(std::move(Functions)) {} + + void clear() { + GroupName.clear(); + Functions.clear(); + } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + LLVM_DUMP_METHOD void dump() const { + constexpr size_t INDENT = 4; + dbgs().indent(INDENT) << "ENTRY POINTS" + << " " << GroupName << " {\n"; + for (const Function *F : Functions) + dbgs().indent(INDENT) << " " << F->getName() << "\n"; + + dbgs().indent(INDENT) << "}\n"; + } +#endif +}; + +/// Annotates an llvm::Module with information necessary to perform and track +/// the result of device code (llvm::Module instances) splitting: +/// - entry points group from the module. 
+class ModuleDesc { + std::unique_ptr M; + EntryPointGroup EntryPoints; + +public: + ModuleDesc() = delete; + ModuleDesc(const ModuleDesc &) = delete; + ModuleDesc &operator=(const ModuleDesc &) = delete; + ModuleDesc(ModuleDesc &&) = default; + ModuleDesc &operator=(ModuleDesc &&) = default; + + ModuleDesc(std::unique_ptr M, + EntryPointGroup EntryPoints = EntryPointGroup()) + : M(std::move(M)), EntryPoints(std::move(EntryPoints)) { + assert(this->M && "Module should be non-null"); + } + + Module &getModule() { return *M; } + const Module &getModule() const { return *M; } + + std::unique_ptr releaseModule() { + EntryPoints.clear(); + return std::move(M); + } + + std::string makeSymbolTable() const { + SmallString<0> Data; + raw_svector_ostream OS(Data); + for (const Function *F : EntryPoints.Functions) + OS << F->getName() << '\n'; + + return std::string(OS.str()); + } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + LLVM_DUMP_METHOD void dump() const { + dbgs() << "ModuleDesc[" << M->getName() << "] {\n"; + EntryPoints.dump(); + dbgs() << "}\n"; + } +#endif +}; + +// Represents "dependency" or "use" graph of global objects (functions and +// global variables) in a module. It is used during device code split to +// understand which global variables and functions (other than entry points) +// should be included into a split module. +// +// Nodes of the graph represent LLVM's GlobalObjects, edges "A" -> "B" represent +// the fact that if "A" is included into a module, then "B" should be included +// as well. +// +// Examples of dependencies which are represented in this graph: +// - Function FA calls function FB +// - Function FA uses global variable GA +// - Global variable GA references (initialized with) function FB +// - Function FA stores address of a function FB somewhere +// +// The following cases are treated as dependencies between global objects: +// 1. 
Global object A is used within by a global object B in any way (store, +// bitcast, phi node, call, etc.): "B" -> "A" edge will be added to the +// graph; +// 2. function A performs an indirect call of a function with signature S and +// there is a function B with signature S. "A" -> "B" edge will be added to +// the graph; +class DependencyGraph { +public: + using GlobalSet = SmallPtrSet; + + DependencyGraph(const Module &M) { + // Group functions by their signature to handle case (2) described above + DenseMap + FuncTypeToFuncsMap; + for (const auto &F : M.functions()) { + // Kernels can't be called (either directly or indirectly) in SYCL + if (isKernel(F)) + continue; + + FuncTypeToFuncsMap[F.getFunctionType()].insert(&F); + } + + for (const auto &F : M.functions()) { + // case (1), see comment above the class definition + for (const Value *U : F.users()) + addUserToGraphRecursively(cast(U), &F); + + // case (2), see comment above the class definition + for (const auto &I : instructions(F)) { + const auto *CI = dyn_cast(&I); + if (!CI || !CI->isIndirectCall()) // Direct calls were handled above + continue; + + const FunctionType *Signature = CI->getFunctionType(); + const auto &PotentialCallees = FuncTypeToFuncsMap[Signature]; + Graph[&F].insert(PotentialCallees.begin(), PotentialCallees.end()); + } + } + + // And every global variable (but their handling is a bit simpler) + for (const auto &GV : M.globals()) + for (const Value *U : GV.users()) + addUserToGraphRecursively(cast(U), &GV); + } + + iterator_range + dependencies(const GlobalValue *Val) const { + auto It = Graph.find(Val); + return (It == Graph.end()) + ?
make_range(EmptySet.begin(), EmptySet.end()) + : make_range(It->second.begin(), It->second.end()); + } + +private: + void addUserToGraphRecursively(const User *Root, const GlobalValue *V) { + SmallVector WorkList; + WorkList.push_back(Root); + + while (!WorkList.empty()) { + const User *U = WorkList.pop_back_val(); + if (const auto *I = dyn_cast(U)) { + const auto *UFunc = I->getFunction(); + Graph[UFunc].insert(V); + } else if (isa(U)) { + if (const auto *GV = dyn_cast(U)) + Graph[GV].insert(V); + // This could be a global variable or some constant expression (like + // bitcast or gep). We trace users of this constant further to reach + // global objects they are used by and add them to the graph. + for (const auto *UU : U->users()) + WorkList.push_back(UU); + } else + llvm_unreachable("Unhandled type of function user"); + } + } + + DenseMap Graph; + SmallPtrSet EmptySet; +}; + +void collectFunctionsAndGlobalVariablesToExtract( + SetVector &GVs, const Module &M, + const EntryPointGroup &ModuleEntryPoints, const DependencyGraph &DG) { + // We start with module entry points + for (const auto *F : ModuleEntryPoints.Functions) + GVs.insert(F); + + // Non-discardable global variables are also included in the initial set + for (const auto &GV : M.globals()) + if (!GV.isDiscardableIfUnused()) + GVs.insert(&GV); + + // GVs has SetVector type. This type inserts a value only if it is not yet + // present there. So, recursion is not expected here. + size_t Idx = 0; + while (Idx < GVs.size()) { + const GlobalValue *Obj = GVs[Idx++]; + + for (const GlobalValue *Dep : DG.dependencies(Obj)) { + if (const auto *Func = dyn_cast(Dep)) { + if (!Func->isDeclaration()) + GVs.insert(Func); + } else + GVs.insert(Dep); // Global variables are added unconditionally + } + } +} + +ModuleDesc extractSubModule(const Module &M, + const SetVector &GVs, + EntryPointGroup ModuleEntryPoints) { + // For each group of entry points collect all dependencies.
+ ValueToValueMapTy VMap; + // Clone definitions only for needed globals. Others will be added as + // declarations and removed later. + std::unique_ptr SubM = CloneModule( + M, VMap, [&](const GlobalValue *GV) { return GVs.count(GV); }); + // Replace entry points with cloned ones. + EntryPointSet NewEPs; + const EntryPointSet &EPs = ModuleEntryPoints.Functions; + std::for_each(EPs.begin(), EPs.end(), [&](const Function *F) { + NewEPs.insert(cast(VMap[F])); + }); + ModuleEntryPoints.Functions = std::move(NewEPs); + return ModuleDesc{std::move(SubM), std::move(ModuleEntryPoints)}; +} + +// The function produces a copy of input LLVM IR module M with only those +// functions and globals that can be called from entry points that are specified +// in ModuleEntryPoints vector, in addition to the entry point functions. +ModuleDesc extractCallGraph(const Module &M, EntryPointGroup ModuleEntryPoints, + const DependencyGraph &DG) { + SetVector GVs; + collectFunctionsAndGlobalVariablesToExtract(GVs, M, ModuleEntryPoints, DG); + + ModuleDesc SplitM = extractSubModule(M, GVs, std::move(ModuleEntryPoints)); + LLVM_DEBUG(SplitM.dump()); + return SplitM; +} + +using EntryPointGroupVec = SmallVector; + +/// Module Splitter. +/// It gets a module (in a form of module descriptor, to get additional info) +/// and a collection of entry point groups. Each group specifies a subset of +/// entry points from input module that should be included in a split module.
+class ModuleSplitter { +private: + ModuleDesc Input; + EntryPointGroupVec Groups; + DependencyGraph DG; + +private: + EntryPointGroup drawEntryPointGroup() { + assert(Groups.size() > 0 && "Reached end of entry point groups list."); + EntryPointGroup Group = std::move(Groups.back()); + Groups.pop_back(); + return Group; + } + +public: + ModuleSplitter(ModuleDesc MD, EntryPointGroupVec GroupVec) + : Input(std::move(MD)), Groups(std::move(GroupVec)), + DG(Input.getModule()) { + assert(!Groups.empty() && "Entry points groups collection is empty!"); + } + + /// Gets next subsequence of entry points in an input module and provides + /// split submodule containing these entry points and their dependencies. + ModuleDesc getNextSplit() { + return extractCallGraph(Input.getModule(), drawEntryPointGroup(), DG); + } + + /// Check that there are still submodules to split. + bool hasMoreSplits() const { return Groups.size() > 0; } +}; + +} // namespace + +static EntryPointGroupVec selectEntryPointGroups(const Module &M, + IRSplitMode Mode) { + // std::map is used here to ensure stable ordering of entry point groups, + // which is based on their contents, this greatly helps LIT tests + std::map EntryPointsMap; + + static constexpr char ATTR_SYCL_MODULE_ID[] = "sycl-module-id"; + for (const auto &F : M.functions()) { + if (!isEntryPoint(F)) + continue; + + std::string Key; + switch (Mode) { + case IRSplitMode::IRSM_PER_KERNEL: + Key = F.getName(); + break; + case IRSplitMode::IRSM_PER_TU: + Key = F.getFnAttribute(ATTR_SYCL_MODULE_ID).getValueAsString(); + break; + case IRSplitMode::IRSM_NONE: + llvm_unreachable(""); + } + + EntryPointsMap[Key].insert(&F); + } + + EntryPointGroupVec Groups; + if (EntryPointsMap.empty()) { + // No entry points met, record this. 
+ Groups.emplace_back("-", EntryPointSet()); + } else { + Groups.reserve(EntryPointsMap.size()); + // Start with properties of a source module + for (auto &[Key, EntryPoints] : EntryPointsMap) + Groups.emplace_back(Key, std::move(EntryPoints)); + } + + return Groups; +} + +namespace llvm { + +std::optional convertStringToSplitMode(StringRef S) { + static const StringMap Values = { + {"source", IRSplitMode::IRSM_PER_TU}, + {"kernel", IRSplitMode::IRSM_PER_KERNEL}, + {"none", IRSplitMode::IRSM_NONE}}; + + auto It = Values.find(S); + if (It == Values.end()) + return std::nullopt; + + return It->second; +} + +void SYCLSplitModule(std::unique_ptr M, IRSplitMode Mode, + PostSYCLSplitCallbackType Callback) { + SmallVector OutputImages; + if (Mode == IRSplitMode::IRSM_NONE) { + auto MD = ModuleDesc(std::move(M)); + auto Symbols = MD.makeSymbolTable(); + Callback(std::move(MD.releaseModule()), std::move(Symbols)); + return; + } + + EntryPointGroupVec Groups = selectEntryPointGroups(*M, Mode); + ModuleDesc MD = std::move(M); + ModuleSplitter Splitter(std::move(MD), std::move(Groups)); + while (Splitter.hasMoreSplits()) { + ModuleDesc MD = Splitter.getNextSplit(); + auto Symbols = MD.makeSymbolTable(); + Callback(std::move(MD.releaseModule()), std::move(Symbols)); + } +} + +} // namespace llvm diff --git a/llvm/lib/Transforms/Utils/SYCLUtils.cpp b/llvm/lib/Transforms/Utils/SYCLUtils.cpp new file mode 100644 index 0000000000000..ad9864fadb828 --- /dev/null +++ b/llvm/lib/Transforms/Utils/SYCLUtils.cpp @@ -0,0 +1,26 @@ +//===------------ SYCLUtils.cpp - SYCL utility functions ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// SYCL utility functions. 
+//===----------------------------------------------------------------------===// +#include "llvm/Transforms/Utils/SYCLUtils.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +namespace llvm { + +void writeSYCLStringTable(const SYCLStringTable &Table, raw_ostream &OS) { + assert(!Table.empty() && "table should contain at least column titles"); + assert(!Table[0].empty() && "table should be non-empty"); + OS << '[' << join(Table[0].begin(), Table[0].end(), "|") << "]\n"; + for (size_t I = 1, E = Table.size(); I != E; ++I) { + assert(Table[I].size() == Table[0].size() && "row's size should be equal"); + OS << join(Table[I].begin(), Table[I].end(), "|") << '\n'; + } +} + +} // namespace llvm diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/amd-kernel-split.ll b/llvm/test/tools/llvm-split/SYCL/device-code-split/amd-kernel-split.ll new file mode 100644 index 0000000000000..a40a52107fb0c --- /dev/null +++ b/llvm/test/tools/llvm-split/SYCL/device-code-split/amd-kernel-split.ll @@ -0,0 +1,17 @@ +; -- Per-kernel split +; RUN: llvm-split -sycl-split=kernel -S < %s -o %tC +; RUN: FileCheck %s -input-file=%tC_0.ll --check-prefixes CHECK-A0 +; RUN: FileCheck %s -input-file=%tC_1.ll --check-prefixes CHECK-A1 + +define dso_local amdgpu_kernel void @KernelA() { + ret void +} + +define dso_local amdgpu_kernel void @KernelB() { + ret void +} + +; CHECK-A0: define dso_local amdgpu_kernel void @KernelB() +; CHECK-A0-NOT: define dso_local amdgpu_kernel void @KernelA() +; CHECK-A1-NOT: define dso_local amdgpu_kernel void @KernelB() +; CHECK-A1: define dso_local amdgpu_kernel void @KernelA() diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/complex-indirect-call-chain.ll b/llvm/test/tools/llvm-split/SYCL/device-code-split/complex-indirect-call-chain.ll new file mode 100644 index 0000000000000..5a25e491b1b93 --- /dev/null +++ b/llvm/test/tools/llvm-split/SYCL/device-code-split/complex-indirect-call-chain.ll @@ -0,0 +1,75 @@ +; Check 
that Module splitting can trace through more complex call stacks +; involving several nested indirect calls. + +; RUN: llvm-split -sycl-split=source -S < %s -o %t +; RUN: FileCheck %s -input-file=%t_0.ll --check-prefix CHECK0 \ +; RUN: --implicit-check-not @foo --implicit-check-not @kernel_A \ +; RUN: --implicit-check-not @kernel_B --implicit-check-not @baz +; RUN: FileCheck %s -input-file=%t_1.ll --check-prefix CHECK1 \ +; RUN: --implicit-check-not @kernel_A --implicit-check-not @kernel_C +; RUN: FileCheck %s -input-file=%t_2.ll --check-prefix CHECK2 \ +; RUN: --implicit-check-not @foo --implicit-check-not @bar \ +; RUN: --implicit-check-not @BAZ --implicit-check-not @kernel_B \ +; RUN: --implicit-check-not @kernel_C + +; RUN: llvm-split -sycl-split=kernel -S < %s -o %t +; RUN: FileCheck %s -input-file=%t_0.ll --check-prefix CHECK0 \ +; RUN: --implicit-check-not @foo --implicit-check-not @kernel_A \ +; RUN: --implicit-check-not @kernel_B +; RUN: FileCheck %s -input-file=%t_1.ll --check-prefix CHECK1 \ +; RUN: --implicit-check-not @kernel_A --implicit-check-not @kernel_C +; RUN: FileCheck %s -input-file=%t_2.ll --check-prefix CHECK2 \ +; RUN: --implicit-check-not @foo --implicit-check-not @bar \ +; RUN: --implicit-check-not @BAZ --implicit-check-not @kernel_B \ +; RUN: --implicit-check-not @kernel_C + +; CHECK0-DAG: define spir_kernel void @kernel_C +; CHECK0-DAG: define spir_func i32 @bar +; CHECK0-DAG: define spir_func void @baz +; CHECK0-DAG: define spir_func void @BAZ + +; CHECK1-DAG: define spir_kernel void @kernel_B +; CHECK1-DAG: define {{.*}}spir_func i32 @foo +; CHECK1-DAG: define spir_func i32 @bar +; CHECK1-DAG: define spir_func void @baz +; CHECK1-DAG: define spir_func void @BAZ + +; CHECK2-DAG: define spir_kernel void @kernel_A +; CHECK2-DAG: define {{.*}}spir_func void @baz + +define spir_func i32 @foo(i32 (i32, void ()*)* %ptr1, void ()* %ptr2) { + %1 = call spir_func i32 %ptr1(i32 42, void ()* %ptr2) + ret i32 %1 +} + +define spir_func i32 @bar(i32 
%arg, void ()* %ptr) { + call spir_func void %ptr() + ret i32 %arg +} + +define spir_func void @baz() { + ret void +} + +define spir_func void @BAZ() { + ret void +} + +define spir_kernel void @kernel_A() #0 { + call spir_func void @baz() + ret void +} + +define spir_kernel void @kernel_B() #1 { + call spir_func i32 @foo(i32 (i32, void ()*)* null, void ()* null) + ret void +} + +define spir_kernel void @kernel_C() #2 { + call spir_func i32 @bar(i32 42, void ()* null) + ret void +} + +attributes #0 = { "sycl-module-id"="TU1.cpp" } +attributes #1 = { "sycl-module-id"="TU2.cpp" } +attributes #2 = { "sycl-module-id"="TU3.cpp" } diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/module-split-func-ptr.ll b/llvm/test/tools/llvm-split/SYCL/device-code-split/module-split-func-ptr.ll new file mode 100644 index 0000000000000..c9289d78b1fda --- /dev/null +++ b/llvm/test/tools/llvm-split/SYCL/device-code-split/module-split-func-ptr.ll @@ -0,0 +1,43 @@ +; This test checks that Module splitting can properly perform device code split by tracking +; all uses of functions (not only direct calls). 
+ +; RUN: llvm-split -sycl-split=source -S < %s -o %t +; RUN: FileCheck %s -input-file=%t_0.sym --check-prefix=CHECK-SYM0 +; RUN: FileCheck %s -input-file=%t_1.sym --check-prefix=CHECK-SYM1 +; RUN: FileCheck %s -input-file=%t_0.ll --check-prefix=CHECK-IR0 +; RUN: FileCheck %s -input-file=%t_1.ll --check-prefix=CHECK-IR1 + +; CHECK-SYM0: kernelA +; CHECK-SYM1: kernelB +; +; CHECK-IR0: define dso_local spir_kernel void @kernelA +; +; CHECK-IR1: @FuncTable = weak global ptr @func +; CHECK-IR1: define {{.*}} i32 @func +; CHECK-IR1: define weak_odr dso_local spir_kernel void @kernelB + +@FuncTable = weak global ptr @func, align 8 + +define dso_local spir_func i32 @func(i32 %a) { +entry: + ret i32 %a +} + +define weak_odr dso_local spir_kernel void @kernelB() #0 { +entry: + %0 = call i32 @indirect_call(ptr addrspace(4) addrspacecast ( ptr getelementptr inbounds ( [1 x ptr] , ptr @FuncTable, i64 0, i64 0) to ptr addrspace(4)), i32 0) + ret void +} + +define dso_local spir_kernel void @kernelA() #1 { +entry: + ret void +} + +declare dso_local spir_func i32 @indirect_call(ptr addrspace(4), i32) local_unnamed_addr + +attributes #0 = { "sycl-module-id"="TU1.cpp" } +attributes #1 = { "sycl-module-id"="TU2.cpp" } + +; CHECK: kernel1 +; CHECK: kernel2 diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/one-kernel-per-module.ll b/llvm/test/tools/llvm-split/SYCL/device-code-split/one-kernel-per-module.ll new file mode 100644 index 0000000000000..b949ab7530f39 --- /dev/null +++ b/llvm/test/tools/llvm-split/SYCL/device-code-split/one-kernel-per-module.ll @@ -0,0 +1,108 @@ +; Test checks "kernel" splitting mode. 
+ +; RUN: llvm-split -sycl-split=kernel -S < %s -o %t.files +; RUN: FileCheck %s -input-file=%t.files_0.ll --check-prefixes CHECK-MODULE0,CHECK +; RUN: FileCheck %s -input-file=%t.files_0.sym --check-prefixes CHECK-MODULE0-TXT +; RUN: FileCheck %s -input-file=%t.files_1.ll --check-prefixes CHECK-MODULE1,CHECK +; RUN: FileCheck %s -input-file=%t.files_1.sym --check-prefixes CHECK-MODULE1-TXT +; RUN: FileCheck %s -input-file=%t.files_2.ll --check-prefixes CHECK-MODULE2,CHECK +; RUN: FileCheck %s -input-file=%t.files_2.sym --check-prefixes CHECK-MODULE2-TXT + +;CHECK-MODULE0: @GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 +;CHECK-MODULE1-NOT: @GV +;CHECK-MODULE2-NOT: @GV +@GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 + +; CHECK-MODULE0-TXT-NOT: TU0_kernelA +; CHECK-MODULE1-TXT-NOT: TU0_kernelA +; CHECK-MODULE2-TXT: TU0_kernelA + +; CHECK-MODULE0-NOT: define dso_local spir_kernel void @TU0_kernelA +; CHECK-MODULE1-NOT: define dso_local spir_kernel void @TU0_kernelA +; CHECK-MODULE2: define dso_local spir_kernel void @TU0_kernelA +define dso_local spir_kernel void @TU0_kernelA() #0 { +entry: +; CHECK-MODULE2: call spir_func void @foo() + call spir_func void @foo() + ret void +} + +; CHECK-MODULE0-NOT: define {{.*}} spir_func void @foo() +; CHECK-MODULE1-NOT: define {{.*}} spir_func void @foo() +; CHECK-MODULE2: define {{.*}} spir_func void @foo() +define dso_local spir_func void @foo() { +entry: +; CHECK-MODULE2: call spir_func void @bar() + call spir_func void @bar() + ret void +} + +; CHECK-MODULE0-NOT: define {{.*}} spir_func void @bar() +; CHECK-MODULE1-NOT: define {{.*}} spir_func void @bar() +; CHECK-MODULE2: define {{.*}} spir_func void @bar() +define linkonce_odr dso_local spir_func void @bar() { +entry: + ret void +} + +; CHECK-MODULE0-TXT-NOT: TU0_kernelB +; CHECK-MODULE1-TXT: TU0_kernelB +; CHECK-MODULE2-TXT-NOT: TU0_kernelB + +; CHECK-MODULE0-NOT: define dso_local spir_kernel void @TU0_kernelB() +; CHECK-MODULE1: define
dso_local spir_kernel void @TU0_kernelB() +; CHECK-MODULE2-NOT: define dso_local spir_kernel void @TU0_kernelB() +define dso_local spir_kernel void @TU0_kernelB() #0 { +entry: +; CHECK-MODULE1: call spir_func void @foo1() + call spir_func void @foo1() + ret void +} + +; CHECK-MODULE0-NOT: define {{.*}} spir_func void @foo1() +; CHECK-MODULE1: define {{.*}} spir_func void @foo1() +; CHECK-MODULE2-NOT: define {{.*}} spir_func void @foo1() +define dso_local spir_func void @foo1() { +entry: + ret void +} + +; CHECK-MODULE0-TXT: TU1_kernel +; CHECK-MODULE1-TXT-NOT: TU1_kernel +; CHECK-MODULE2-TXT-NOT: TU1_kernel + +; CHECK-MODULE0: define dso_local spir_kernel void @TU1_kernel() +; CHECK-MODULE1-NOT: define dso_local spir_kernel void @TU1_kernel() +; CHECK-MODULE2-NOT: define dso_local spir_kernel void @TU1_kernel() +define dso_local spir_kernel void @TU1_kernel() #1 { +entry: +; CHECK-MODULE0: call spir_func void @foo2() + call spir_func void @foo2() + ret void +} + +; CHECK-MODULE0: define {{.*}} spir_func void @foo2() +; CHECK-MODULE1-NOT: define {{.*}} spir_func void @foo2() +; CHECK-MODULE2-NOT: define {{.*}} spir_func void @foo2() +define dso_local spir_func void @foo2() { +entry: +; CHECK-MODULE0: %0 = load i32, ptr addrspace(4) addrspacecast (ptr addrspace(1) @GV to ptr addrspace(4)), align 4 + %0 = load i32, ptr addrspace(4) getelementptr inbounds ([1 x i32], ptr addrspace(4) addrspacecast (ptr addrspace(1) @GV to ptr addrspace(4)), i64 0, i64 0), align 4 + ret void +} + +attributes #0 = { "sycl-module-id"="TU1.cpp" } +attributes #1 = { "sycl-module-id"="TU2.cpp" } + +; Metadata is saved in both modules. 
+; CHECK: !opencl.spir.version = !{!0, !0} +; CHECK: !spirv.Source = !{!1, !1} + +!opencl.spir.version = !{!0, !0} +!spirv.Source = !{!1, !1} + +; CHECK: !0 = !{i32 1, i32 2} +; CHECK: !1 = !{i32 4, i32 100000} + +!0 = !{i32 1, i32 2} +!1 = !{i32 4, i32 100000} diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/split-by-source.ll b/llvm/test/tools/llvm-split/SYCL/device-code-split/split-by-source.ll new file mode 100644 index 0000000000000..6a4e543209526 --- /dev/null +++ b/llvm/test/tools/llvm-split/SYCL/device-code-split/split-by-source.ll @@ -0,0 +1,97 @@ +; Test checks that kernels are being split by attached TU metadata and +; used functions are being moved with kernels that use them. + +; RUN: llvm-split -sycl-split=source -S < %s -o %t +; RUN: FileCheck %s -input-file=%t_0.ll --check-prefixes CHECK-TU0,CHECK +; RUN: FileCheck %s -input-file=%t_1.ll --check-prefixes CHECK-TU1,CHECK +; RUN: FileCheck %s -input-file=%t_0.sym --check-prefixes CHECK-TU0-TXT +; RUN: FileCheck %s -input-file=%t_1.sym --check-prefixes CHECK-TU1-TXT + +; CHECK-TU1-NOT: @GV +; CHECK-TU0: @GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 +@GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 + +; CHECK-TU0-TXT-NOT: TU1_kernelA +; CHECK-TU1-TXT: TU1_kernelA + +; CHECK-TU0-NOT: define dso_local spir_kernel void @TU1_kernelA +; CHECK-TU1: define dso_local spir_kernel void @TU1_kernelA +define dso_local spir_kernel void @TU1_kernelA() #0 { +entry: +; CHECK-TU1: call spir_func void @func1_TU1() + call spir_func void @func1_TU1() + ret void +} + +; CHECK-TU0-NOT: define {{.*}} spir_func void @func1_TU1() +; CHECK-TU1: define {{.*}} spir_func void @func1_TU1() +define dso_local spir_func void @func1_TU1() { +entry: +; CHECK-TU1: call spir_func void @func2_TU1() + call spir_func void @func2_TU1() + ret void +} + +; CHECK-TU0-NOT: define {{.*}} spir_func void @func2_TU1() +; CHECK-TU1: define {{.*}} spir_func void @func2_TU1() +define linkonce_odr dso_local
spir_func void @func2_TU1() { +entry: + ret void +} + + +; CHECK-TU0-TXT-NOT: TU1_kernelB +; CHECK-TU1-TXT: TU1_kernelB + +; CHECK-TU0-NOT: define dso_local spir_kernel void @TU1_kernelB() +; CHECK-TU1: define dso_local spir_kernel void @TU1_kernelB() +define dso_local spir_kernel void @TU1_kernelB() #0 { +entry: +; CHECK-TU1: call spir_func void @func3_TU1() + call spir_func void @func3_TU1() + ret void +} + +; CHECK-TU0-NOT: define {{.*}} spir_func void @func3_TU1() +; CHECK-TU1: define {{.*}} spir_func void @func3_TU1() +define dso_local spir_func void @func3_TU1() { +entry: + ret void +} + +; CHECK-TU0-TXT: TU0_kernel +; CHECK-TU1-TXT-NOT: TU0_kernel + +; CHECK-TU0: define dso_local spir_kernel void @TU0_kernel() +; CHECK-TU1-NOT: define dso_local spir_kernel void @TU0_kernel() +define dso_local spir_kernel void @TU0_kernel() #1 { +entry: +; CHECK-TU0: call spir_func void @func_TU0() + call spir_func void @func_TU0() + ret void +} + +; CHECK-TU0: define {{.*}} spir_func void @func_TU0() +; CHECK-TU1-NOT: define {{.*}} spir_func void @func_TU0() +define dso_local spir_func void @func_TU0() { +entry: +; CHECK-TU0: %0 = load i32, ptr addrspace(4) addrspacecast (ptr addrspace(1) @GV to ptr addrspace(4)), align 4 + %0 = load i32, ptr addrspace(4) getelementptr inbounds ([1 x i32], ptr addrspace(4) addrspacecast (ptr addrspace(1) @GV to ptr addrspace(4)), i64 0, i64 0), align 4 + ret void +} + +attributes #0 = { "sycl-module-id"="TU1.cpp" } +attributes #1 = { "sycl-module-id"="TU2.cpp" } + +; Metadata is saved in both modules. 
+; CHECK: !opencl.spir.version = !{!0, !0} +; CHECK: !spirv.Source = !{!1, !1} + +!opencl.spir.version = !{!0, !0} +!spirv.Source = !{!1, !1} + +; CHECK: !0 = !{i32 1, i32 2} +; CHECK: !1 = !{i32 4, i32 100000} + +!0 = !{i32 1, i32 2} +!1 = !{i32 4, i32 100000} diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/split-with-kernel-declarations.ll b/llvm/test/tools/llvm-split/SYCL/device-code-split/split-with-kernel-declarations.ll new file mode 100644 index 0000000000000..1f188d8e32db6 --- /dev/null +++ b/llvm/test/tools/llvm-split/SYCL/device-code-split/split-with-kernel-declarations.ll @@ -0,0 +1,66 @@ +; The test checks that Module splitting does not treat declarations as entry points. + +; RUN: llvm-split -sycl-split=source -S < %s -o %t1 +; RUN: FileCheck %s -input-file=%t1.table --check-prefix CHECK-PER-SOURCE-TABLE +; RUN: FileCheck %s -input-file=%t1_0.sym --check-prefix CHECK-PER-SOURCE-SYM0 +; RUN: FileCheck %s -input-file=%t1_1.sym --check-prefix CHECK-PER-SOURCE-SYM1 + +; RUN: llvm-split -sycl-split=kernel -S < %s -o %t2 +; RUN: FileCheck %s -input-file=%t2.table --check-prefix CHECK-PER-KERNEL-TABLE +; RUN: FileCheck %s -input-file=%t2_0.sym --check-prefix CHECK-PER-KERNEL-SYM0 +; RUN: FileCheck %s -input-file=%t2_1.sym --check-prefix CHECK-PER-KERNEL-SYM1 +; RUN: FileCheck %s -input-file=%t2_2.sym --check-prefix CHECK-PER-KERNEL-SYM2 + +; With per-source split, there should be two device images +; CHECK-PER-SOURCE-TABLE: [Code|Symbols] +; CHECK-PER-SOURCE-TABLE: {{.*}}_0.ll|{{.*}}_0.sym +; CHECK-PER-SOURCE-TABLE-NEXT: {{.*}}_1.ll|{{.*}}_1.sym +; CHECK-PER-SOURCE-TABLE-EMPTY: +; +; CHECK-PER-SOURCE-SYM0-NOT: TU1_kernel1 +; CHECK-PER-SOURCE-SYM0: TU1_kernel0 +; CHECK-PER-SOURCE-SYM0-EMPTY: +; +; CHECK-PER-SOURCE-SYM1-NOT: TU1_kernel1 +; CHECK-PER-SOURCE-SYM1: TU0_kernel0 +; CHECK-PER-SOURCE-SYM1-NEXT: TU0_kernel1 +; CHECK-PER-SOURCE-SYM1-EMPTY: + +; With per-kernel split, there should be three device images +; CHECK-PER-KERNEL-TABLE: 
[Code|Symbols] +; CHECK-PER-KERNEL-TABLE: {{.*}}_0.ll|{{.*}}_0.sym +; CHECK-PER-KERNEL-TABLE-NEXT: {{.*}}_1.ll|{{.*}}_1.sym +; CHECK-PER-KERNEL-TABLE-NEXT: {{.*}}_2.ll|{{.*}}_2.sym +; CHECK-PER-KERNEL-TABLE-EMPTY: +; +; CHECK-PER-KERNEL-SYM0-NOT: TU1_kernel1 +; CHECK-PER-KERNEL-SYM0: TU1_kernel0 +; CHECK-PER-KERNEL-SYM0-EMPTY: +; +; CHECK-PER-KERNEL-SYM1-NOT: TU1_kernel1 +; CHECK-PER-KERNEL-SYM1: TU0_kernel1 +; CHECK-PER-KERNEL-SYM1-EMPTY: +; +; CHECK-PER-KERNEL-SYM2-NOT: TU1_kernel1 +; CHECK-PER-KERNEL-SYM2: TU0_kernel0 +; CHECK-PER-KERNEL-SYM2-EMPTY: + + +define spir_kernel void @TU0_kernel0() #0 { +entry: + ret void +} + +define spir_kernel void @TU0_kernel1() #0 { +entry: + ret void +} + +define spir_kernel void @TU1_kernel0() #1 { + ret void +} + +declare spir_kernel void @TU1_kernel1() #1 + +attributes #0 = { "sycl-module-id"="TU1.cpp" } +attributes #1 = { "sycl-module-id"="TU2.cpp" } diff --git a/llvm/tools/llvm-split/CMakeLists.txt b/llvm/tools/llvm-split/CMakeLists.txt index 1104e3145952c..b755755a984fc 100644 --- a/llvm/tools/llvm-split/CMakeLists.txt +++ b/llvm/tools/llvm-split/CMakeLists.txt @@ -12,6 +12,7 @@ set(LLVM_LINK_COMPONENTS Support Target TargetParser + ipo ) add_llvm_tool(llvm-split diff --git a/llvm/tools/llvm-split/llvm-split.cpp b/llvm/tools/llvm-split/llvm-split.cpp index 9f6678a1fa466..f6e90985304d6 100644 --- a/llvm/tools/llvm-split/llvm-split.cpp +++ b/llvm/tools/llvm-split/llvm-split.cpp @@ -11,14 +11,19 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PassInstrumentation.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/Verifier.h" #include "llvm/IRReader/IRReader.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileSystem.h" +#include 
"llvm/Support/FormatVariadic.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" @@ -27,6 +32,9 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "llvm/TargetParser/Triple.h" +#include "llvm/Transforms/IPO/GlobalDCE.h" +#include "llvm/Transforms/Utils/SYCLSplitModule.h" +#include "llvm/Transforms/Utils/SYCLUtils.h" #include "llvm/Transforms/Utils/SplitModule.h" using namespace llvm; @@ -70,6 +78,108 @@ static cl::opt MCPU("mcpu", cl::desc("Target CPU, ignored if --mtriple is not used"), cl::value_desc("cpu"), cl::cat(SplitCategory)); +static cl::opt SYCLSplitMode( + "sycl-split", + cl::desc("SYCL Split Mode. If present, SYCL splitting algorithm is used " + "with the specified mode."), + cl::Optional, cl::init(IRSplitMode::IRSM_NONE), + cl::values(clEnumValN(IRSplitMode::IRSM_PER_TU, "source", + "1 ouptput module per translation unit"), + clEnumValN(IRSplitMode::IRSM_PER_KERNEL, "kernel", + "1 output module per kernel")), + cl::cat(SplitCategory)); + +static cl::opt OutputAssembly{ + "S", cl::desc("Write output as LLVM assembly"), cl::cat(SplitCategory)}; + +void writeStringToFile(StringRef Content, StringRef Path) { + std::error_code EC; + raw_fd_ostream OS(Path, EC); + if (EC) { + errs() << formatv("error opening file: {0}, error: {1}\n", Path, + EC.message()); + exit(1); + } + + OS << Content << "\n"; +} + +void writeModuleToFile(const Module &M, StringRef Path, bool OutputAssembly) { + int FD = -1; + if (std::error_code EC = sys::fs::openFileForWrite(Path, FD)) { + errs() << formatv("error opening file: {0}, error: {1}", Path, EC.message()) + << '\n'; + exit(1); + } + + raw_fd_ostream OS(FD, /*ShouldClose*/ true); + if (OutputAssembly) + M.print(OS, /*AssemblyAnnotationWriter*/ nullptr); + else + WriteBitcodeToFile(M, OS); +} + +void writeSplitModulesAsTable(ArrayRef Modules, + StringRef Path) { + SmallVector> Columns; + Columns.emplace_back("Code"); + 
Columns.emplace_back("Symbols"); + + SYCLStringTable Table; + Table.emplace_back(std::move(Columns)); + for (const auto &[I, SM] : enumerate(Modules)) { + SmallString<128> SymbolsFile; + (Twine(Path) + "_" + Twine(I) + ".sym").toVector(SymbolsFile); + writeStringToFile(SM.Symbols, SymbolsFile); + SmallVector> Row; + Row.emplace_back(SM.ModuleFilePath); + Row.emplace_back(SymbolsFile); + Table.emplace_back(std::move(Row)); + } + + std::error_code EC; + raw_fd_ostream OS((Path + ".table").str(), EC); + if (EC) { + errs() << formatv("error opening file: {0}\n", Path); + exit(1); + } + + writeSYCLStringTable(Table, OS); +} + +void cleanupModule(Module &M) { + ModuleAnalysisManager MAM; + MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); + ModulePassManager MPM; + MPM.addPass(GlobalDCEPass()); // Delete unreachable globals. + MPM.run(M, MAM); +} + +Error runSYCLSplitModule(std::unique_ptr M) { + SmallVector SplitModules; + auto PostSYCLSplitCallback = [&](std::unique_ptr MPart, + std::string Symbols) { + if (verifyModule(*MPart)) { + errs() << "Broken Module!\n"; + exit(1); + } + + // TODO: DCE is a crucial pass in a SYCL post-link pipeline. + // At the moment, LIT checking can't be perfomed without DCE. + cleanupModule(*MPart); + size_t ID = SplitModules.size(); + StringRef ModuleSuffix = OutputAssembly ? 
".ll" : ".bc"; + std::string ModulePath = + (Twine(OutputFilename) + "_" + Twine(ID) + ModuleSuffix).str(); + writeModuleToFile(*MPart, ModulePath, OutputAssembly); + SplitModules.emplace_back(std::move(ModulePath), std::move(Symbols)); + }; + + SYCLSplitModule(std::move(M), SYCLSplitMode, PostSYCLSplitCallback); + writeSplitModulesAsTable(SplitModules, OutputFilename); + return Error::success(); +} + int main(int argc, char **argv) { InitLLVM X(argc, argv); @@ -123,6 +233,17 @@ int main(int argc, char **argv) { Out->keep(); }; + if (SYCLSplitMode != IRSplitMode::IRSM_NONE) { + auto E = runSYCLSplitModule(std::move(M)); + if (E) { + errs() << E << "\n"; + Err.print(argv[0], errs()); + return 1; + } + + return 0; + } + if (TM) { if (PreserveLocals) { errs() << "warning: --preserve-locals has no effect when using "