Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sycl-jit/common/include/Kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,10 @@ struct RTCBundleInfo {
RTCBundleBinaryInfo BinaryInfo;
FrozenSymbolTable SymbolTable;
FrozenPropertyRegistry Properties;

RTCBundleInfo() = default;
RTCBundleInfo(RTCBundleInfo &&) = default;
RTCBundleInfo &operator=(RTCBundleInfo &&) = default;
};

} // namespace jit_compiler
Expand Down
2 changes: 2 additions & 0 deletions sycl-jit/jit-compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_llvm_library(sycl-jit
lib/fusion/JITContext.cpp
lib/fusion/ModuleHelper.cpp
lib/rtc/DeviceCompilation.cpp
lib/rtc/ESIMD.cpp
lib/helper/ConfigHelper.cpp

SHARED
Expand All @@ -32,6 +33,7 @@ add_llvm_library(sycl-jit
TargetParser
MC
SYCLLowerIR
GenXIntrinsics
${LLVM_TARGETS_TO_BUILD}

LINK_LIBS
Expand Down
9 changes: 5 additions & 4 deletions sycl-jit/jit-compiler/lib/KernelFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,13 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
}

auto BundleInfoOrError = performPostLink(*Module, UserArgList);
if (!BundleInfoOrError) {
return errorTo<RTCResult>(BundleInfoOrError.takeError(),
auto PostLinkResultOrError = performPostLink(std::move(Module), UserArgList);
if (!PostLinkResultOrError) {
return errorTo<RTCResult>(PostLinkResultOrError.takeError(),
"Post-link phase failed");
}
auto BundleInfo = std::move(*BundleInfoOrError);
RTCBundleInfo BundleInfo;
std::tie(BundleInfo, Module) = std::move(*PostLinkResultOrError);

auto BinaryInfoOrError =
translation::KernelTranslator::translateBundleToSPIRV(
Expand Down
86 changes: 62 additions & 24 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "DeviceCompilation.h"
#include "ESIMD.h"

#include <clang/Basic/DiagnosticDriver.h>
#include <clang/Basic/Version.h>
Expand All @@ -23,6 +24,8 @@
#include <llvm/IRReader/IRReader.h>
#include <llvm/Linker/Linker.h>
#include <llvm/SYCLLowerIR/ComputeModuleRuntimeInfo.h>
#include <llvm/SYCLLowerIR/ESIMD/LowerESIMD.h>
#include <llvm/SYCLLowerIR/LowerInvokeSimd.h>
#include <llvm/SYCLLowerIR/ModuleSplitter.h>
#include <llvm/SYCLLowerIR/SYCLJointMatrixTransform.h>
#include <llvm/Support/PropertySetIO.h>
Expand Down Expand Up @@ -376,42 +379,82 @@ template <class PassClass> static bool runModulePass(llvm::Module &M) {
return !Res.areAllPreserved();
}

Expected<RTCBundleInfo> jit_compiler::performPostLink(
llvm::Module &Module, [[maybe_unused]] const InputArgList &UserArgList) {
llvm::Expected<PostLinkResult> jit_compiler::performPostLink(
std::unique_ptr<llvm::Module> Module,
[[maybe_unused]] const llvm::opt::InputArgList &UserArgList) {
// This is a simplified version of `processInputModule` in
// `llvm/tools/sycl-post-link.cpp`. Assertions/TODOs point to functionality
// left out of the algorithm for now.

assert(!Module.getGlobalVariable("llvm.used") &&
!Module.getGlobalVariable("llvm.compiler.used"));
// TODO: SplitMode can be controlled by the user.
const auto SplitMode = SPLIT_NONE;

// TODO: EmitOnlyKernelsAsEntryPoints is controlled by
// `shouldEmitOnlyKernelsAsEntryPoints` in
// `clang/lib/Driver/ToolChains/Clang.cpp`.
const bool EmitOnlyKernelsAsEntryPoints = true;

// TODO: The optlevel passed to `sycl-post-link` is determined by
// `getSYCLPostLinkOptimizationLevel` in
// `clang/lib/Driver/ToolChains/Clang.cpp`.
const bool PerformOpts = true;

// Propagate ESIMD attribute to wrapper functions to prevent spurious splits
// and kernel link errors.
runModulePass<SYCLFixupESIMDKernelWrapperMDPass>(*Module);

assert(!Module->getGlobalVariable("llvm.used") &&
!Module->getGlobalVariable("llvm.compiler.used"));
// Otherwise: Port over the `removeSYCLKernelsConstRefArray` and
// `removeDeviceGlobalFromCompilerUsed` methods.

assert(!isModuleUsingAsan(Module));
assert(!isModuleUsingAsan(*Module));
// Otherwise: Need to instrument each image scope device globals if the module
// has been instrumented by sanitizer pass.

// Transform Joint Matrix builtin calls to align them with SPIR-V friendly
// LLVM IR specification.
runModulePass<SYCLJointMatrixTransformPass>(Module);
runModulePass<SYCLJointMatrixTransformPass>(*Module);

// Do invoke_simd processing before splitting because this:
// - saves processing time (the pass is run once, even though on larger IR)
// - doing it before SYCL/ESIMD splitting is required for correctness
if (runModulePass<SYCLLowerInvokeSimdPass>(*Module)) {
return createStringError("`invoke_simd` calls detected");
}

// TODO: Implement actual device code splitting. We're just using the splitter
// to obtain additional information about the module for now.
// TODO: EmitOnlyKernelsAsEntryPoints is controlled by
// `shouldEmitOnlyKernelsAsEntryPoints` in
// `clang/lib/Driver/ToolChains/Clang.cpp`.

std::unique_ptr<ModuleSplitterBase> Splitter = getDeviceCodeSplitter(
ModuleDesc{std::unique_ptr<llvm::Module>{&Module}}, SPLIT_NONE,
/*IROutputOnly=*/false,
/*EmitOnlyKernelsAsEntryPoints=*/true);
assert(Splitter->remainingSplits() == 1);
ModuleDesc{std::move(Module)}, SplitMode,
/*IROutputOnly=*/false, EmitOnlyKernelsAsEntryPoints);
assert(Splitter->hasMoreSplits());
if (Splitter->remainingSplits() > 1) {
return createStringError("Device code requires splitting");
}

// TODO: Call `verifyNoCrossModuleDeviceGlobalUsage` if device globals shall
// be processed.

assert(Splitter->hasMoreSplits());
ModuleDesc MDesc = Splitter->nextSplit();
assert(&Module == &MDesc.getModule());

// TODO: Call `MDesc.fixupLinkageOfDirectInvokeSimdTargets()` when
// `invoke_simd` is supported.

SmallVector<ModuleDesc, 2> ESIMDSplits =
splitByESIMD(std::move(MDesc), EmitOnlyKernelsAsEntryPoints);
assert(!ESIMDSplits.empty());
if (ESIMDSplits.size() > 1) {
return createStringError("Mixing SYCL and ESIMD code is unsupported");
}
MDesc = std::move(ESIMDSplits.front());

if (MDesc.isESIMD()) {
// TODO: We're assuming ESIMD lowering is not deactivated. The driver only
// omits `-lower-esimd` for `sycl-post-link` in the IR-only output mode.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would it be deactivated? Through a user option?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The driver would deactivate it (i.e. not pass -lower-esimd to sycl-post-link) in the IR-only output mode, which is only relevant for spec constant processing IIUC. Otherwise it doesn't seem to be influenced by a user-option. I updated the comment to reflect this.

lowerEsimdConstructs(MDesc, PerformOpts);
}

MDesc.saveSplitInformationAsMetadata();

RTCBundleInfo BundleInfo;
Expand Down Expand Up @@ -448,10 +491,7 @@ Expected<RTCBundleInfo> jit_compiler::performPostLink(
}
};

// Regain ownership of the module.
MDesc.releaseModulePtr().release();

return std::move(BundleInfo);
return PostLinkResult{std::move(BundleInfo), MDesc.releaseModulePtr()};
}

Expected<InputArgList>
Expand Down Expand Up @@ -513,11 +553,9 @@ jit_compiler::parseUserArgs(View<const char *> UserArgs) {
return createStringError("Device code splitting is not yet supported");
}

if (AL.hasArg(OPT_fsycl_device_code_split_esimd,
OPT_fno_sycl_device_code_split_esimd)) {
// TODO: There are more ESIMD-related options.
return createStringError(
"Runtime compilation of ESIMD kernels is not yet supported");
if (!AL.hasFlag(OPT_fsycl_device_code_split_esimd,
OPT_fno_sycl_device_code_split_esimd, true)) {
return createStringError("ESIMD device code split cannot be deactivated");
}

if (AL.hasFlag(OPT_fsycl_dead_args_optimization,
Expand Down
5 changes: 3 additions & 2 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ compileDeviceCode(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
llvm::Error linkDeviceLibraries(llvm::Module &Module,
const llvm::opt::InputArgList &UserArgList);

llvm::Expected<RTCBundleInfo>
performPostLink(llvm::Module &Module,
using PostLinkResult = std::pair<RTCBundleInfo, std::unique_ptr<llvm::Module>>;
llvm::Expected<PostLinkResult>
performPostLink(std::unique_ptr<llvm::Module> Module,
const llvm::opt::InputArgList &UserArgList);

llvm::Expected<llvm::opt::InputArgList>
Expand Down
77 changes: 77 additions & 0 deletions sycl-jit/jit-compiler/lib/rtc/ESIMD.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
//===------------- ESIMD.cpp - Driver for ESIMD lowering ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "ESIMD.h"

#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/GenXIntrinsics/GenXSPIRVWriterAdaptor.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/SYCLLowerIR/ESIMD/LowerESIMD.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar/DCE.h"
#include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/Transforms/Scalar/SROA.h"

using namespace llvm;

using string_vector = std::vector<std::string>;

// Lowers the ESIMD constructs in the module held by `MD`. This is only safe
// once ESIMD code has been separated from the regular SYCL code, i.e. `MD` is
// expected to contain ESIMD code only.
//
// `PerformOpts` enables the optional clean-up passes (SROA, EarlyCSE,
// InstCombine, DCE) in addition to the lowering passes that always run.
//
// The module is modified in place. Because the final adaptor pass replaces
// some functions with rewritten versions, the entry points of `MD` are saved
// by name before the pipeline runs and rebuilt from those names afterwards.
void jit_compiler::lowerEsimdConstructs(module_split::ModuleDesc &MD,
                                        bool PerformOpts) {
  LoopAnalysisManager LAM;
  CGSCCAnalysisManager CGAM;
  FunctionAnalysisManager FAM;
  ModuleAnalysisManager MAM;

  // Register the analyses with their managers and wire up the proxies so that
  // passes can query analyses across IR-unit levels.
  PassBuilder PB;
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(FAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

  ModulePassManager MPM;
  MPM.addPass(SYCLLowerESIMDPass(/*ModuleContainsScalar=*/false));

  if (PerformOpts) {
    FunctionPassManager FPM;
    FPM.addPass(SROAPass(SROAOptions::ModifyCFG));
    MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
  }
  MPM.addPass(ESIMDOptimizeVecArgCallConvPass{});

  FunctionPassManager MainFPM;
  MainFPM.addPass(ESIMDLowerLoadStorePass{});

  // The clean-up sequence below is an optimization and must only be scheduled
  // when optimizations are requested, consistent with the SROA guard above.
  if (PerformOpts) {
    MainFPM.addPass(SROAPass(SROAOptions::ModifyCFG));
    MainFPM.addPass(EarlyCSEPass(true));
    MainFPM.addPass(InstCombinePass{});
    MainFPM.addPass(DCEPass{});
    // TODO: maybe remove some passes below that don't affect code quality
    MainFPM.addPass(SROAPass(SROAOptions::ModifyCFG));
    MainFPM.addPass(EarlyCSEPass(true));
    MainFPM.addPass(InstCombinePass{});
    MainFPM.addPass(DCEPass{});
  }
  MPM.addPass(ESIMDLowerSLMReservationCalls{});
  MPM.addPass(createModuleToFunctionPassAdaptor(std::move(MainFPM)));
  MPM.addPass(GenXSPIRVWriterAdaptor(/*RewriteTypes=*/true,
                                     /*RewriteSingleElementVectorsIn=*/false));

  // GenXSPIRVWriterAdaptor pass replaced some functions with "rewritten"
  // versions so the entry point table must be rebuilt.
  // TODO Change entry point search to analysis?
  std::vector<std::string> Names;
  MD.saveEntryPointNames(Names);
  MPM.run(MD.getModule(), MAM);
  MD.rebuildEntryPoints(Names);
}
23 changes: 23 additions & 0 deletions sycl-jit/jit-compiler/lib/rtc/ESIMD.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//===-------------- ESIMD.h - Driver for ESIMD lowering -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef SYCL_JIT_COMPILER_RTC_ESIMD_H
#define SYCL_JIT_COMPILER_RTC_ESIMD_H

#include "llvm/SYCLLowerIR/ModuleSplitter.h"

namespace jit_compiler {

// Runs a pass pipeline to lower ESIMD constructs on the given split module,
// which must only contain ESIMD entrypoints. This is a copy of the similar
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may or must?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

must

// function in `sycl-post-link`.
//
// `MD` is modified in place: the pipeline runs on its module and its entry
// points are rebuilt afterwards. `PerformOpts` toggles the optional
// optimization passes of the pipeline.
void lowerEsimdConstructs(llvm::module_split::ModuleDesc &MD, bool PerformOpts);

} // namespace jit_compiler

#endif // SYCL_JIT_COMPILER_RTC_ESIMD_H
Loading
Loading