@@ -60,11 +60,6 @@ constexpr char SYCL_SCOPE_NAME[] = "<SYCL>";
 constexpr char ESIMD_SCOPE_NAME[] = "<ESIMD>";
 constexpr char ESIMD_MARKER_MD[] = "sycl_explicit_simd";
 
-cl::opt<bool> AllowDeviceImageDependencies{
-    "allow-device-image-dependencies",
-    cl::desc("Allow dependencies between device images"),
-    cl::cat(getModuleSplitCategory()), cl::init(false)};
-
 EntryPointsGroupScope selectDeviceCodeGroupScope(const Module &M,
                                                  IRSplitMode Mode,
                                                  bool AutoSplitIsGlobalScope) {
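
Note on the hunk above: it removes the splitting library's own cl::opt, so the knob now has to reach the splitter as an explicit parameter. A minimal caller-side sketch (not part of this patch; the surrounding tool code, includes, and variable names are assumed) of how a front-end tool could keep exposing the same switch and forward its value through the updated getDeviceCodeSplitter signature shown further down:

    // Hypothetical tool-side flag; the library no longer registers one itself.
    static cl::opt<bool> AllowDeviceImageDependencies(
        "allow-device-image-dependencies",
        cl::desc("Allow dependencies between device images"), cl::init(false));

    // Later, when splitting (MD and Mode come from the tool's own context):
    auto Splitter =
        getDeviceCodeSplitter(std::move(MD), Mode, /*IROutputOnly=*/false,
                              /*EmitOnlyKernelsAsEntryPoints=*/false,
                              AllowDeviceImageDependencies);
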
@@ -178,7 +173,7 @@ class DependencyGraph {
 public:
   using GlobalSet = SmallPtrSet<const GlobalValue *, 16>;
 
-  DependencyGraph(const Module &M) {
+  DependencyGraph(const Module &M, bool AllowDeviceImageDependencies) {
     // Group functions by their signature to handle case (2) described above
     DenseMap<const FunctionType *, DependencyGraph::GlobalSet>
         FuncTypeToFuncsMap;
@@ -196,7 +191,7 @@ class DependencyGraph {
     }
 
     for (const auto &F : M.functions()) {
-      if (canBeImportedFunction(F))
+      if (canBeImportedFunction(F, AllowDeviceImageDependencies))
         continue;
 
       // case (1), see comment above the class definition
@@ -311,7 +306,9 @@ static bool isIntrinsicOrBuiltin(const Function &F) {
 }
 
 // Checks for use of undefined user functions and emits a warning message.
-static void checkForCallsToUndefinedFunctions(const Module &M) {
+static void
+checkForCallsToUndefinedFunctions(const Module &M,
+                                  bool AllowDeviceImageDependencies) {
   if (AllowDeviceImageDependencies)
     return;
   for (const Function &F : M) {
@@ -391,11 +388,11 @@ ModuleDesc extractSubModule(const ModuleDesc &MD,
 // The function produces a copy of input LLVM IR module M with only those
 // functions and globals that can be called from entry points that are specified
 // in ModuleEntryPoints vector, in addition to the entry point functions.
-ModuleDesc extractCallGraph(const ModuleDesc &MD,
-                            EntryPointGroup &&ModuleEntryPoints,
-                            const DependencyGraph &CG,
-                            const std::function<bool(const Function *)>
-                                &IncludeFunctionPredicate = nullptr) {
+ModuleDesc extractCallGraph(
+    const ModuleDesc &MD, EntryPointGroup &&ModuleEntryPoints,
+    const DependencyGraph &CG, bool AllowDeviceImageDependencies,
+    const std::function<bool(const Function *)> &IncludeFunctionPredicate =
+        nullptr) {
   SetVector<const GlobalValue *> GVs;
   collectFunctionsAndGlobalVariablesToExtract(
       GVs, MD.getModule(), ModuleEntryPoints, CG, IncludeFunctionPredicate);
@@ -405,20 +402,21 @@ ModuleDesc extractCallGraph(const ModuleDesc &MD,
   // sycl-post-link. This call is redundant. However, we subsequently run
   // GenXSPIRVWriterAdaptor pass that relies on this cleanup. This cleanup call
   // can be removed once that pass no longer depends on this cleanup.
-  SplitM.cleanup();
-  checkForCallsToUndefinedFunctions(SplitM.getModule());
+  SplitM.cleanup(AllowDeviceImageDependencies);
+  checkForCallsToUndefinedFunctions(SplitM.getModule(),
+                                    AllowDeviceImageDependencies);
 
   return SplitM;
 }
 
 // The function is similar to 'extractCallGraph', but it produces a copy of
 // input LLVM IR module M with _all_ ESIMD functions and kernels included,
 // regardless of whether or not they are listed in ModuleEntryPoints.
-ModuleDesc extractESIMDSubModule(const ModuleDesc &MD,
-                                 EntryPointGroup &&ModuleEntryPoints,
-                                 const DependencyGraph &CG,
-                                 const std::function<bool(const Function *)>
-                                     &IncludeFunctionPredicate = nullptr) {
+ModuleDesc extractESIMDSubModule(
+    const ModuleDesc &MD, EntryPointGroup &&ModuleEntryPoints,
+    const DependencyGraph &CG, bool AllowDeviceImageDependencies,
+    const std::function<bool(const Function *)> &IncludeFunctionPredicate =
+        nullptr) {
   SetVector<const GlobalValue *> GVs;
   for (const auto &F : MD.getModule().functions())
     if (isESIMDFunction(F))
@@ -432,7 +430,7 @@ ModuleDesc extractESIMDSubModule(const ModuleDesc &MD,
   // sycl-post-link. This call is redundant. However, we subsequently run
   // GenXSPIRVWriterAdaptor pass that relies on this cleanup. This cleanup call
   // can be removed once that pass no longer depends on this cleanup.
-  SplitM.cleanup();
+  SplitM.cleanup(AllowDeviceImageDependencies);
 
   return SplitM;
 }
@@ -449,19 +447,22 @@ class ModuleCopier : public ModuleSplitterBase {
     // sycl-post-link. This call is redundant. However, we subsequently run
     // GenXSPIRVWriterAdaptor pass that relies on this cleanup. This cleanup
     // call can be removed once that pass no longer depends on this cleanup.
-    Desc.cleanup();
+    Desc.cleanup(AllowDeviceImageDependencies);
     return Desc;
   }
 };
 
 class ModuleSplitter : public ModuleSplitterBase {
 public:
-  ModuleSplitter(ModuleDesc &&MD, EntryPointGroupVec &&GroupVec)
-      : ModuleSplitterBase(std::move(MD), std::move(GroupVec)),
-        CG(Input.getModule()) {}
+  ModuleSplitter(ModuleDesc &&MD, EntryPointGroupVec &&GroupVec,
+                 bool AllowDeviceImageDependencies)
+      : ModuleSplitterBase(std::move(MD), std::move(GroupVec),
+                           AllowDeviceImageDependencies),
+        CG(Input.getModule(), AllowDeviceImageDependencies) {}
 
   ModuleDesc nextSplit() override {
-    return extractCallGraph(Input, nextGroup(), CG);
+    return extractCallGraph(Input, nextGroup(), CG,
+                            AllowDeviceImageDependencies);
   }
 
 private:
@@ -489,11 +490,6 @@ bool isESIMDFunction(const Function &F) {
   return F.getMetadata(ESIMD_MARKER_MD) != nullptr;
 }
 
-cl::OptionCategory &getModuleSplitCategory() {
-  static cl::OptionCategory ModuleSplitCategory{"Module Split options"};
-  return ModuleSplitCategory;
-}
-
 Error ModuleSplitterBase::verifyNoCrossModuleDeviceGlobalUsage() {
   const Module &M = getInputModule();
   // Early exit if there is only one group
@@ -692,7 +688,8 @@ void ModuleDesc::restoreLinkageOfDirectInvokeSimdTargets() {
 // tries to internalize absolutely everything. This function serves as "input
 // from a linker" that tells the pass what must be preserved in order to make
 // the transformation safe.
-static bool mustPreserveGV(const GlobalValue &GV) {
+static bool mustPreserveGV(const GlobalValue &GV,
+                           bool AllowDeviceImageDependencies) {
   if (const Function *F = dyn_cast<Function>(&GV)) {
     // When dynamic linking is supported, we internalize everything (except
     // kernels which are the entry points from host code to device code) that
@@ -703,7 +700,8 @@ static bool mustPreserveGV(const GlobalValue &GV) {
     const bool SpirOrGPU = CC == CallingConv::SPIR_KERNEL ||
                            CC == CallingConv::AMDGPU_KERNEL ||
                            CC == CallingConv::PTX_Kernel;
-    return SpirOrGPU || canBeImportedFunction(*F);
+    return SpirOrGPU ||
+           canBeImportedFunction(*F, AllowDeviceImageDependencies);
   }
 
   // Otherwise, we are being even more aggressive: SYCL modules are expected
@@ -754,7 +752,7 @@ void cleanupSYCLRegisteredKernels(Module *M) {
 
 // TODO: try to move all passes (cleanup, spec consts, compile time properties)
 // in one place and execute MPM.run() only once.
-void ModuleDesc::cleanup() {
+void ModuleDesc::cleanup(bool AllowDeviceImageDependencies) {
   // Any definitions of virtual functions should be removed and turned into
   // declarations, they are supposed to be provided by a different module.
   if (!EntryPoints.Props.HasVirtualFunctionDefinitions) {
@@ -781,7 +779,10 @@ void ModuleDesc::cleanup() {
   MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
   ModulePassManager MPM;
   // Do cleanup.
-  MPM.addPass(InternalizePass(mustPreserveGV));
+  MPM.addPass(
+      InternalizePass([AllowDeviceImageDependencies](const GlobalValue &GV) {
+        return mustPreserveGV(GV, AllowDeviceImageDependencies);
+      }));
   MPM.addPass(GlobalDCEPass());           // Delete unreachable globals.
   MPM.addPass(StripDeadDebugInfoPass());  // Remove dead debug info.
   MPM.addPass(StripDeadPrototypesPass()); // Remove dead func decls.
@@ -1157,7 +1158,8 @@ std::string FunctionsCategorizer::computeCategoryFor(Function *F) const {
 
 std::unique_ptr<ModuleSplitterBase>
 getDeviceCodeSplitter(ModuleDesc &&MD, IRSplitMode Mode, bool IROutputOnly,
-                      bool EmitOnlyKernelsAsEntryPoints) {
+                      bool EmitOnlyKernelsAsEntryPoints,
+                      bool AllowDeviceImageDependencies) {
   FunctionsCategorizer Categorizer;
 
   EntryPointsGroupScope Scope =
@@ -1252,9 +1254,11 @@ getDeviceCodeSplitter(ModuleDesc &&MD, IRSplitMode Mode, bool IROutputOnly,
       (Groups.size() > 1 || !Groups.cbegin()->Functions.empty()));
 
   if (DoSplit)
-    return std::make_unique<ModuleSplitter>(std::move(MD), std::move(Groups));
+    return std::make_unique<ModuleSplitter>(std::move(MD), std::move(Groups),
+                                            AllowDeviceImageDependencies);
 
-  return std::make_unique<ModuleCopier>(std::move(MD), std::move(Groups));
+  return std::make_unique<ModuleCopier>(std::move(MD), std::move(Groups),
+                                        AllowDeviceImageDependencies);
 }
 
 // Splits input module into two:
@@ -1277,7 +1281,8 @@ getDeviceCodeSplitter(ModuleDesc &&MD, IRSplitMode Mode, bool IROutputOnly,
 // avoid undefined behavior at later stages. That is done at higher level,
 // outside of this function.
 SmallVector<ModuleDesc, 2> splitByESIMD(ModuleDesc &&MD,
-                                        bool EmitOnlyKernelsAsEntryPoints) {
+                                        bool EmitOnlyKernelsAsEntryPoints,
+                                        bool AllowDeviceImageDependencies) {
 
   SmallVector<module_split::ModuleDesc, 2> Result;
   EntryPointGroupVec EntryPointGroups{};
@@ -1320,12 +1325,13 @@ SmallVector<ModuleDesc, 2> splitByESIMD(ModuleDesc &&MD,
     return Result;
   }
 
-  DependencyGraph CG(MD.getModule());
+  DependencyGraph CG(MD.getModule(), AllowDeviceImageDependencies);
   for (auto &Group : EntryPointGroups) {
     if (Group.isEsimd()) {
       // For ESIMD module, we use full call graph of all entry points and all
       // ESIMD functions.
-      Result.emplace_back(extractESIMDSubModule(MD, std::move(Group), CG));
+      Result.emplace_back(extractESIMDSubModule(MD, std::move(Group), CG,
+                                                AllowDeviceImageDependencies));
     } else {
       // For non-ESIMD module we only use non-ESIMD functions. Additional filter
       // is needed, because there could be uses of ESIMD functions from
@@ -1334,7 +1340,7 @@ SmallVector<ModuleDesc, 2> splitByESIMD(ModuleDesc &&MD,
       // were processed and therefore it is fine to return an "incomplete"
       // module here.
       Result.emplace_back(extractCallGraph(
-          MD, std::move(Group), CG,
+          MD, std::move(Group), CG, AllowDeviceImageDependencies,
           [=](const Function *F) -> bool { return !isESIMDFunction(*F); }));
     }
   }
@@ -1477,7 +1483,8 @@ splitSYCLModule(std::unique_ptr<Module> M, ModuleSplitterSettings Settings) {
   // FIXME: false arguments are temporary for now.
   auto Splitter = getDeviceCodeSplitter(std::move(MD), Settings.Mode,
                                         /*IROutputOnly=*/false,
-                                        /*EmitOnlyKernelsAsEntryPoints=*/false);
+                                        /*EmitOnlyKernelsAsEntryPoints=*/false,
+                                        Settings.AllowDeviceImageDependencies);
 
   size_t ID = 0;
   std::vector<SplitModule> OutputImages;
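
For the splitSYCLModule path above, the flag rides inside ModuleSplitterSettings rather than being passed loose (note Settings.AllowDeviceImageDependencies in the hunk). A rough usage sketch, not part of this patch; the namespace qualification, headers, and the origin of M and Mode are assumed:

    ModuleSplitterSettings Settings;
    Settings.Mode = Mode;                         // an IRSplitMode picked by the caller
    Settings.AllowDeviceImageDependencies = true; // the new per-call knob
    auto Images = splitSYCLModule(std::move(M), Settings); // error handling elided
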
@@ -1498,7 +1505,8 @@ splitSYCLModule(std::unique_ptr<Module> M, ModuleSplitterSettings Settings) {
   return OutputImages;
 }
 
-bool canBeImportedFunction(const Function &F) {
+bool canBeImportedFunction(const Function &F,
+                           bool AllowDeviceImageDependencies) {
 
   // We use sycl dynamic library mechanism to involve bf16 devicelib when
   // necessary, all __devicelib_* functions from native or fallback bf16