Skip to content

Commit 7d97996

Browse files
committed
[SYCL][RTC] Preliminary support for ESIMD kernels
Signed-off-by: Julian Oppermann <[email protected]>
1 parent 814290d commit 7d97996

File tree

8 files changed

+264
-32
lines changed

8 files changed

+264
-32
lines changed

sycl-jit/common/include/Kernel.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,10 @@ struct RTCBundleInfo {
403403
RTCBundleBinaryInfo BinaryInfo;
404404
FrozenSymbolTable SymbolTable;
405405
FrozenPropertyRegistry Properties;
406+
407+
RTCBundleInfo() = default;
408+
RTCBundleInfo(RTCBundleInfo &&) = default;
409+
RTCBundleInfo &operator=(RTCBundleInfo &&) = default;
406410
};
407411

408412
} // namespace jit_compiler

sycl-jit/jit-compiler/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_llvm_library(sycl-jit
88
lib/fusion/JITContext.cpp
99
lib/fusion/ModuleHelper.cpp
1010
lib/rtc/DeviceCompilation.cpp
11+
lib/rtc/ESIMD.cpp
1112
lib/helper/ConfigHelper.cpp
1213

1314
SHARED
@@ -32,6 +33,7 @@ add_llvm_library(sycl-jit
3233
TargetParser
3334
MC
3435
SYCLLowerIR
36+
GenXIntrinsics
3537
${LLVM_TARGETS_TO_BUILD}
3638

3739
LINK_LIBS

sycl-jit/jit-compiler/lib/KernelFusion.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,12 +258,13 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
258258
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
259259
}
260260

261-
auto BundleInfoOrError = performPostLink(*Module, UserArgList);
262-
if (!BundleInfoOrError) {
263-
return errorTo<RTCResult>(BundleInfoOrError.takeError(),
261+
auto PostLinkResultOrError = performPostLink(std::move(Module), UserArgList);
262+
if (!PostLinkResultOrError) {
263+
return errorTo<RTCResult>(PostLinkResultOrError.takeError(),
264264
"Post-link phase failed");
265265
}
266-
auto BundleInfo = std::move(*BundleInfoOrError);
266+
RTCBundleInfo BundleInfo;
267+
std::tie(BundleInfo, Module) = std::move(*PostLinkResultOrError);
267268

268269
auto BinaryInfoOrError =
269270
translation::KernelTranslator::translateBundleToSPIRV(

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp

Lines changed: 62 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "DeviceCompilation.h"
10+
#include "ESIMD.h"
1011

1112
#include <clang/Basic/DiagnosticDriver.h>
1213
#include <clang/Basic/Version.h>
@@ -23,6 +24,8 @@
2324
#include <llvm/IRReader/IRReader.h>
2425
#include <llvm/Linker/Linker.h>
2526
#include <llvm/SYCLLowerIR/ComputeModuleRuntimeInfo.h>
27+
#include <llvm/SYCLLowerIR/ESIMD/LowerESIMD.h>
28+
#include <llvm/SYCLLowerIR/LowerInvokeSimd.h>
2629
#include <llvm/SYCLLowerIR/ModuleSplitter.h>
2730
#include <llvm/SYCLLowerIR/SYCLJointMatrixTransform.h>
2831
#include <llvm/Support/PropertySetIO.h>
@@ -376,42 +379,82 @@ template <class PassClass> static bool runModulePass(llvm::Module &M) {
376379
return !Res.areAllPreserved();
377380
}
378381

379-
Expected<RTCBundleInfo> jit_compiler::performPostLink(
380-
llvm::Module &Module, [[maybe_unused]] const InputArgList &UserArgList) {
382+
llvm::Expected<PostLinkResult> jit_compiler::performPostLink(
383+
std::unique_ptr<llvm::Module> Module,
384+
[[maybe_unused]] const llvm::opt::InputArgList &UserArgList) {
381385
// This is a simplified version of `processInputModule` in
382386
// `llvm/tools/sycl-post-link.cpp`. Assertions/TODOs point to functionality
383387
// left out of the algorithm for now.
384388

385-
assert(!Module.getGlobalVariable("llvm.used") &&
386-
!Module.getGlobalVariable("llvm.compiler.used"));
389+
// TODO: SplitMode can be controlled by the user.
390+
const auto SplitMode = SPLIT_NONE;
391+
392+
// TODO: EmitOnlyKernelsAsEntryPoints is controlled by
393+
// `shouldEmitOnlyKernelsAsEntryPoints` in
394+
// `clang/lib/Driver/ToolChains/Clang.cpp`.
395+
const bool EmitOnlyKernelsAsEntryPoints = true;
396+
397+
// TODO: The optlevel passed to `sycl-post-link` is determined by
398+
// `getSYCLPostLinkOptimizationLevel` in
399+
// `clang/lib/Driver/ToolChains/Clang.cpp`.
400+
const bool PerformOpts = true;
401+
402+
// Propagate ESIMD attribute to wrapper functions to prevent spurious splits
403+
// and kernel link errors.
404+
runModulePass<SYCLFixupESIMDKernelWrapperMDPass>(*Module);
405+
406+
assert(!Module->getGlobalVariable("llvm.used") &&
407+
!Module->getGlobalVariable("llvm.compiler.used"));
387408
// Otherwise: Port over the `removeSYCLKernelsConstRefArray` and
388409
// `removeDeviceGlobalFromCompilerUsed` methods.
389410

390-
assert(!isModuleUsingAsan(Module));
411+
assert(!isModuleUsingAsan(*Module));
391412
// Otherwise: Need to instrument each image scope device globals if the module
392413
// has been instrumented by sanitizer pass.
393414

394415
// Transform Joint Matrix builtin calls to align them with SPIR-V friendly
395416
// LLVM IR specification.
396-
runModulePass<SYCLJointMatrixTransformPass>(Module);
417+
runModulePass<SYCLJointMatrixTransformPass>(*Module);
418+
419+
// Do invoke_simd processing before splitting because this:
420+
// - saves processing time (the pass is run once, even though on larger IR)
421+
// - doing it before SYCL/ESIMD splitting is required for correctness
422+
if (runModulePass<SYCLLowerInvokeSimdPass>(*Module)) {
423+
return createStringError("`invoke_simd` calls detected");
424+
}
397425

398426
// TODO: Implement actual device code splitting. We're just using the splitter
399427
// to obtain additional information about the module for now.
400-
// TODO: EmitOnlyKernelsAsEntryPoints is controlled by
401-
// `shouldEmitOnlyKernelsAsEntryPoints` in
402-
// `clang/lib/Driver/ToolChains/Clang.cpp`.
428+
403429
std::unique_ptr<ModuleSplitterBase> Splitter = getDeviceCodeSplitter(
404-
ModuleDesc{std::unique_ptr<llvm::Module>{&Module}}, SPLIT_NONE,
405-
/*IROutputOnly=*/false,
406-
/*EmitOnlyKernelsAsEntryPoints=*/true);
407-
assert(Splitter->remainingSplits() == 1);
430+
ModuleDesc{std::move(Module)}, SplitMode,
431+
/*IROutputOnly=*/false, EmitOnlyKernelsAsEntryPoints);
432+
assert(Splitter->hasMoreSplits());
433+
if (Splitter->remainingSplits() > 1) {
434+
return createStringError("Device code requires splitting");
435+
}
408436

409437
// TODO: Call `verifyNoCrossModuleDeviceGlobalUsage` if device globals shall
410438
// be processed.
411439

412-
assert(Splitter->hasMoreSplits());
413440
ModuleDesc MDesc = Splitter->nextSplit();
414-
assert(&Module == &MDesc.getModule());
441+
442+
// TODO: Call `MDesc.fixupLinkageOfDirectInvokeSimdTargets()` when
443+
// `invoke_simd` is supported.
444+
445+
SmallVector<ModuleDesc, 2> ESIMDSplits =
446+
splitByESIMD(std::move(MDesc), EmitOnlyKernelsAsEntryPoints);
447+
assert(!ESIMDSplits.empty());
448+
if (ESIMDSplits.size() > 1) {
449+
return createStringError("Mixing SYCL and ESIMD code is unsupported");
450+
}
451+
MDesc = std::move(ESIMDSplits.front());
452+
453+
if (MDesc.isESIMD()) {
454+
// TODO: We're assuming ESIMD lowering is not deactivated (why would it?).
455+
lowerEsimdConstructs(MDesc, PerformOpts);
456+
}
457+
415458
MDesc.saveSplitInformationAsMetadata();
416459

417460
RTCBundleInfo BundleInfo;
@@ -448,10 +491,7 @@ Expected<RTCBundleInfo> jit_compiler::performPostLink(
448491
}
449492
};
450493

451-
// Regain ownership of the module.
452-
MDesc.releaseModulePtr().release();
453-
454-
return std::move(BundleInfo);
494+
return PostLinkResult{std::move(BundleInfo), MDesc.releaseModulePtr()};
455495
}
456496

457497
Expected<InputArgList>
@@ -513,11 +553,9 @@ jit_compiler::parseUserArgs(View<const char *> UserArgs) {
513553
return createStringError("Device code splitting is not yet supported");
514554
}
515555

516-
if (AL.hasArg(OPT_fsycl_device_code_split_esimd,
517-
OPT_fno_sycl_device_code_split_esimd)) {
518-
// TODO: There are more ESIMD-related options.
519-
return createStringError(
520-
"Runtime compilation of ESIMD kernels is not yet supported");
556+
if (!AL.hasFlag(OPT_fsycl_device_code_split_esimd,
557+
OPT_fno_sycl_device_code_split_esimd, true)) {
558+
return createStringError("ESIMD device code split cannot be deactivated");
521559
}
522560

523561
if (AL.hasFlag(OPT_fsycl_dead_args_optimization,

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ compileDeviceCode(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
2727
llvm::Error linkDeviceLibraries(llvm::Module &Module,
2828
const llvm::opt::InputArgList &UserArgList);
2929

30-
llvm::Expected<RTCBundleInfo>
31-
performPostLink(llvm::Module &Module,
30+
using PostLinkResult = std::pair<RTCBundleInfo, std::unique_ptr<llvm::Module>>;
31+
llvm::Expected<PostLinkResult>
32+
performPostLink(std::unique_ptr<llvm::Module> Module,
3233
const llvm::opt::InputArgList &UserArgList);
3334

3435
llvm::Expected<llvm::opt::InputArgList>
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
//===------------- ESIMD.cpp - Driver for ESIMD lowering ------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "ESIMD.h"
10+
11+
#include "llvm/Analysis/CGSCCPassManager.h"
12+
#include "llvm/Analysis/LoopAnalysisManager.h"
13+
#include "llvm/GenXIntrinsics/GenXSPIRVWriterAdaptor.h"
14+
#include "llvm/IR/Module.h"
15+
#include "llvm/IR/PassManager.h"
16+
#include "llvm/Passes/PassBuilder.h"
17+
#include "llvm/SYCLLowerIR/ESIMD/LowerESIMD.h"
18+
#include "llvm/Transforms/InstCombine/InstCombine.h"
19+
#include "llvm/Transforms/Scalar/DCE.h"
20+
#include "llvm/Transforms/Scalar/EarlyCSE.h"
21+
#include "llvm/Transforms/Scalar/SROA.h"
22+
23+
using namespace llvm;
24+
25+
using string_vector = std::vector<std::string>;
26+
27+
// When ESIMD code was separated from the regular SYCL code,
28+
// we can safely process ESIMD part.
29+
void jit_compiler::lowerEsimdConstructs(module_split::ModuleDesc &MD,
30+
bool PerformOpts) {
31+
LoopAnalysisManager LAM;
32+
CGSCCAnalysisManager CGAM;
33+
FunctionAnalysisManager FAM;
34+
ModuleAnalysisManager MAM;
35+
36+
PassBuilder PB;
37+
PB.registerModuleAnalyses(MAM);
38+
PB.registerCGSCCAnalyses(CGAM);
39+
PB.registerFunctionAnalyses(FAM);
40+
PB.registerLoopAnalyses(LAM);
41+
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
42+
43+
ModulePassManager MPM;
44+
MPM.addPass(SYCLLowerESIMDPass(/*ModuleContainsScalar=*/false));
45+
46+
if (PerformOpts) {
47+
FunctionPassManager FPM;
48+
FPM.addPass(SROAPass(SROAOptions::ModifyCFG));
49+
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
50+
}
51+
MPM.addPass(ESIMDOptimizeVecArgCallConvPass{});
52+
FunctionPassManager MainFPM;
53+
MainFPM.addPass(ESIMDLowerLoadStorePass{});
54+
55+
if (!PerformOpts) {
56+
MainFPM.addPass(SROAPass(SROAOptions::ModifyCFG));
57+
MainFPM.addPass(EarlyCSEPass(true));
58+
MainFPM.addPass(InstCombinePass{});
59+
MainFPM.addPass(DCEPass{});
60+
// TODO: maybe remove some passes below that don't affect code quality
61+
MainFPM.addPass(SROAPass(SROAOptions::ModifyCFG));
62+
MainFPM.addPass(EarlyCSEPass(true));
63+
MainFPM.addPass(InstCombinePass{});
64+
MainFPM.addPass(DCEPass{});
65+
}
66+
MPM.addPass(ESIMDLowerSLMReservationCalls{});
67+
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(MainFPM)));
68+
MPM.addPass(GenXSPIRVWriterAdaptor(/*RewriteTypes=*/true,
69+
/*RewriteSingleElementVectorsIn*/ false));
70+
// GenXSPIRVWriterAdaptor pass replaced some functions with "rewritten"
71+
// versions so the entry point table must be rebuilt.
72+
// TODO Change entry point search to analysis?
73+
std::vector<std::string> Names;
74+
MD.saveEntryPointNames(Names);
75+
MPM.run(MD.getModule(), MAM);
76+
MD.rebuildEntryPoints(Names);
77+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===-------------- ESIMD.h - Driver for ESIMD lowering -------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef SYCL_JIT_COMPILER_RTC_ESIMD_H
10+
#define SYCL_JIT_COMPILER_RTC_ESIMD_H
11+
12+
#include "llvm/SYCLLowerIR/ModuleSplitter.h"
13+
14+
namespace jit_compiler {
15+
16+
// Runs a pass pipeline to lower ESIMD constructs on the given split model,
17+
// which may only contain ESIMD entrypoints. This is a copy of the similar
18+
// function in `sycl-post-link`.
19+
void lowerEsimdConstructs(llvm::module_split::ModuleDesc &MD, bool PerformOpts);
20+
21+
} // namespace jit_compiler
22+
23+
#endif // SYCL_JIT_COMPILER_RTC_ESIMD_H

0 commit comments

Comments
 (0)