|
| 1 | +//===-------- SplitModuleByCategory.cpp - split a module by categories ----===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// See comments in the header. |
| 9 | +//===----------------------------------------------------------------------===// |
| 10 | + |
| 11 | +#include "llvm/Transforms/Utils/SplitModuleByCategory.h" |
| 12 | +#include "llvm/ADT/SetVector.h" |
| 13 | +#include "llvm/ADT/SmallPtrSet.h" |
| 14 | +#include "llvm/ADT/StringExtras.h" |
| 15 | +#include "llvm/IR/Constants.h" |
| 16 | +#include "llvm/IR/Function.h" |
| 17 | +#include "llvm/IR/InstIterator.h" |
| 18 | +#include "llvm/IR/Instructions.h" |
| 19 | +#include "llvm/IR/Module.h" |
| 20 | +#include "llvm/Support/Debug.h" |
| 21 | +#include "llvm/Transforms/Utils/Cloning.h" |
| 22 | + |
| 23 | +#include <map> |
| 24 | +#include <string> |
| 25 | +#include <utility> |
| 26 | + |
| 27 | +using namespace llvm; |
| 28 | + |
| 29 | +#define DEBUG_TYPE "split-module-by-category" |
| 30 | + |
| 31 | +namespace { |
| 32 | + |
| 33 | +// A vector that contains a group of function with the same category. |
| 34 | +using EntryPointSet = SetVector<const Function *>; |
| 35 | + |
| 36 | +/// Represents a group of functions with one category. |
| 37 | +struct EntryPointGroup { |
| 38 | + int ID; |
| 39 | + EntryPointSet Functions; |
| 40 | + |
| 41 | + EntryPointGroup() = default; |
| 42 | + |
| 43 | + EntryPointGroup(int ID, EntryPointSet &&Functions = EntryPointSet()) |
| 44 | + : ID(ID), Functions(std::move(Functions)) {} |
| 45 | + |
| 46 | + void clear() { Functions.clear(); } |
| 47 | + |
| 48 | +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) |
| 49 | + LLVM_DUMP_METHOD void dump() const { |
| 50 | + constexpr size_t INDENT = 4; |
| 51 | + dbgs().indent(INDENT) << "ENTRY POINTS" |
| 52 | + << " " << ID << " {\n"; |
| 53 | + for (const Function *F : Functions) |
| 54 | + dbgs().indent(INDENT) << " " << F->getName() << "\n"; |
| 55 | + |
| 56 | + dbgs().indent(INDENT) << "}\n"; |
| 57 | + } |
| 58 | +#endif |
| 59 | +}; |
| 60 | + |
| 61 | +/// Annotates an llvm::Module with information necessary to perform and track |
| 62 | +/// the result of code (llvm::Module instances) splitting: |
| 63 | +/// - entry points group from the module. |
| 64 | +class ModuleDesc { |
| 65 | + std::unique_ptr<Module> M; |
| 66 | + EntryPointGroup EntryPoints; |
| 67 | + |
| 68 | +public: |
| 69 | + ModuleDesc(std::unique_ptr<Module> M, |
| 70 | + EntryPointGroup &&EntryPoints = EntryPointGroup()) |
| 71 | + : M(std::move(M)), EntryPoints(std::move(EntryPoints)) { |
| 72 | + assert(this->M && "Module should be non-null"); |
| 73 | + } |
| 74 | + |
| 75 | + Module &getModule() { return *M; } |
| 76 | + const Module &getModule() const { return *M; } |
| 77 | + |
| 78 | + std::unique_ptr<Module> releaseModule() { |
| 79 | + EntryPoints.clear(); |
| 80 | + return std::move(M); |
| 81 | + } |
| 82 | + |
| 83 | +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) |
| 84 | + LLVM_DUMP_METHOD void dump() const { |
| 85 | + dbgs() << "ModuleDesc[" << M->getName() << "] {\n"; |
| 86 | + EntryPoints.dump(); |
| 87 | + dbgs() << "}\n"; |
| 88 | + } |
| 89 | +#endif |
| 90 | +}; |
| 91 | + |
| 92 | +bool isKernel(const Function &F) { |
| 93 | + return F.getCallingConv() == CallingConv::SPIR_KERNEL || |
| 94 | + F.getCallingConv() == CallingConv::AMDGPU_KERNEL || |
| 95 | + F.getCallingConv() == CallingConv::PTX_Kernel; |
| 96 | +} |
| 97 | + |
| 98 | +// Represents "dependency" or "use" graph of global objects (functions and |
| 99 | +// global variables) in a module. It is used during code split to |
| 100 | +// understand which global variables and functions (other than entry points) |
| 101 | +// should be included into a split module. |
| 102 | +// |
| 103 | +// Nodes of the graph represent LLVM's GlobalObjects, edges "A" -> "B" represent |
| 104 | +// the fact that if "A" is included into a module, then "B" should be included |
| 105 | +// as well. |
| 106 | +// |
| 107 | +// Examples of dependencies which are represented in this graph: |
| 108 | +// - Function FA calls function FB |
| 109 | +// - Function FA uses global variable GA |
| 110 | +// - Global variable GA references (initialized with) function FB |
| 111 | +// - Function FA stores address of a function FB somewhere |
| 112 | +// |
| 113 | +// The following cases are treated as dependencies between global objects: |
| 114 | +// 1. Global object A is used by a global object B in any way (store, |
| 115 | +// bitcast, phi node, call, etc.): "A" -> "B" edge will be added to the |
| 116 | +// graph; |
| 117 | +// 2. function A performs an indirect call of a function with signature S and |
| 118 | +// there is a function B with signature S. "A" -> "B" edge will be added to |
| 119 | +// the graph; |
| 120 | +class DependencyGraph { |
| 121 | +public: |
| 122 | + using GlobalSet = SmallPtrSet<const GlobalValue *, 16>; |
| 123 | + |
| 124 | + DependencyGraph(const Module &M) { |
| 125 | + // Group functions by their signature to handle case (2) described above |
| 126 | + DenseMap<const FunctionType *, DependencyGraph::GlobalSet> |
| 127 | + FuncTypeToFuncsMap; |
| 128 | + for (const Function &F : M.functions()) { |
| 129 | + // Kernels can't be called (either directly or indirectly). |
| 130 | + if (isKernel(F)) |
| 131 | + continue; |
| 132 | + |
| 133 | + FuncTypeToFuncsMap[F.getFunctionType()].insert(&F); |
| 134 | + } |
| 135 | + |
| 136 | + for (const Function &F : M.functions()) { |
| 137 | + // case (1), see comment above the class definition |
| 138 | + for (const Value *U : F.users()) |
| 139 | + addUserToGraphRecursively(cast<const User>(U), &F); |
| 140 | + |
| 141 | + // case (2), see comment above the class definition |
| 142 | + for (const Instruction &I : instructions(F)) { |
| 143 | + const CallBase *CB = dyn_cast<CallBase>(&I); |
| 144 | + if (!CB || !CB->isIndirectCall()) // Direct calls were handled above |
| 145 | + continue; |
| 146 | + |
| 147 | + const FunctionType *Signature = CB->getFunctionType(); |
| 148 | + GlobalSet &PotentialCallees = FuncTypeToFuncsMap[Signature]; |
| 149 | + Graph[&F].insert(PotentialCallees.begin(), PotentialCallees.end()); |
| 150 | + } |
| 151 | + } |
| 152 | + |
| 153 | + // And every global variable (but their handling is a bit simpler) |
| 154 | + for (const GlobalVariable &GV : M.globals()) |
| 155 | + for (const Value *U : GV.users()) |
| 156 | + addUserToGraphRecursively(cast<const User>(U), &GV); |
| 157 | + } |
| 158 | + |
| 159 | + iterator_range<GlobalSet::const_iterator> |
| 160 | + dependencies(const GlobalValue *Val) const { |
| 161 | + auto It = Graph.find(Val); |
| 162 | + return (It == Graph.end()) |
| 163 | + ? make_range(EmptySet.begin(), EmptySet.end()) |
| 164 | + : make_range(It->second.begin(), It->second.end()); |
| 165 | + } |
| 166 | + |
| 167 | +private: |
| 168 | + void addUserToGraphRecursively(const User *Root, const GlobalValue *V) { |
| 169 | + SmallVector<const User *, 8> WorkList; |
| 170 | + WorkList.push_back(Root); |
| 171 | + |
| 172 | + while (!WorkList.empty()) { |
| 173 | + const User *U = WorkList.pop_back_val(); |
| 174 | + if (const auto *I = dyn_cast<const Instruction>(U)) { |
| 175 | + const Function *UFunc = I->getFunction(); |
| 176 | + Graph[UFunc].insert(V); |
| 177 | + } else if (isa<const Constant>(U)) { |
| 178 | + if (const auto *GV = dyn_cast<const GlobalVariable>(U)) |
| 179 | + Graph[GV].insert(V); |
| 180 | + // This could be a global variable or some constant expression (like |
| 181 | + // bitcast or gep). We trace users of this constant further to reach |
| 182 | + // global objects they are used by and add them to the graph. |
| 183 | + for (const User *UU : U->users()) |
| 184 | + WorkList.push_back(UU); |
| 185 | + } else { |
| 186 | + llvm_unreachable("Unhandled type of function user"); |
| 187 | + } |
| 188 | + } |
| 189 | + } |
| 190 | + |
| 191 | + DenseMap<const GlobalValue *, GlobalSet> Graph; |
| 192 | + SmallPtrSet<const GlobalValue *, 1> EmptySet; |
| 193 | +}; |
| 194 | + |
| 195 | +void collectFunctionsAndGlobalVariablesToExtract( |
| 196 | + SetVector<const GlobalValue *> &GVs, const Module &M, |
| 197 | + const EntryPointGroup &ModuleEntryPoints, const DependencyGraph &DG) { |
| 198 | + // We start with module entry points |
| 199 | + for (const Function *F : ModuleEntryPoints.Functions) |
| 200 | + GVs.insert(F); |
| 201 | + |
| 202 | + // Non-discardable global variables are also include into the initial set |
| 203 | + for (const GlobalVariable &GV : M.globals()) |
| 204 | + if (!GV.isDiscardableIfUnused()) |
| 205 | + GVs.insert(&GV); |
| 206 | + |
| 207 | + // GVs has SetVector type. This type inserts a value only if it is not yet |
| 208 | + // present there. So, recursion is not expected here. |
| 209 | + size_t Idx = 0; |
| 210 | + while (Idx < GVs.size()) { |
| 211 | + const GlobalValue *Obj = GVs[Idx++]; |
| 212 | + |
| 213 | + for (const GlobalValue *Dep : DG.dependencies(Obj)) { |
| 214 | + if (const auto *Func = dyn_cast<const Function>(Dep)) { |
| 215 | + if (!Func->isDeclaration()) |
| 216 | + GVs.insert(Func); |
| 217 | + } else { |
| 218 | + GVs.insert(Dep); // Global variables are added unconditionally |
| 219 | + } |
| 220 | + } |
| 221 | + } |
| 222 | +} |
| 223 | + |
| 224 | +ModuleDesc extractSubModule(const Module &M, |
| 225 | + const SetVector<const GlobalValue *> &GVs, |
| 226 | + EntryPointGroup &&ModuleEntryPoints) { |
| 227 | + ValueToValueMapTy VMap; |
| 228 | + // Clone definitions only for needed globals. Others will be added as |
| 229 | + // declarations and removed later. |
| 230 | + std::unique_ptr<Module> SubM = CloneModule( |
| 231 | + M, VMap, [&](const GlobalValue *GV) { return GVs.contains(GV); }); |
| 232 | + // Replace entry points with cloned ones. |
| 233 | + EntryPointSet NewEPs; |
| 234 | + const EntryPointSet &EPs = ModuleEntryPoints.Functions; |
| 235 | + llvm::for_each( |
| 236 | + EPs, [&](const Function *F) { NewEPs.insert(cast<Function>(VMap[F])); }); |
| 237 | + ModuleEntryPoints.Functions = std::move(NewEPs); |
| 238 | + return ModuleDesc{std::move(SubM), std::move(ModuleEntryPoints)}; |
| 239 | +} |
| 240 | + |
| 241 | +// The function produces a copy of input LLVM IR module M with only those |
| 242 | +// functions and globals that can be called from entry points that are specified |
| 243 | +// in ModuleEntryPoints vector, in addition to the entry point functions. |
| 244 | +ModuleDesc extractCallGraph(const Module &M, |
| 245 | + EntryPointGroup &&ModuleEntryPoints, |
| 246 | + const DependencyGraph &DG) { |
| 247 | + SetVector<const GlobalValue *> GVs; |
| 248 | + collectFunctionsAndGlobalVariablesToExtract(GVs, M, ModuleEntryPoints, DG); |
| 249 | + |
| 250 | + ModuleDesc SplitM = extractSubModule(M, GVs, std::move(ModuleEntryPoints)); |
| 251 | + LLVM_DEBUG(SplitM.dump()); |
| 252 | + return SplitM; |
| 253 | +} |
| 254 | + |
| 255 | +using EntryPointGroupVec = SmallVector<EntryPointGroup>; |
| 256 | + |
| 257 | +/// Module Splitter. |
| 258 | +/// It gets a module and a collection of entry points groups. |
| 259 | +/// Each group specifies subset entry points from input module that should be |
| 260 | +/// included in a split module. |
| 261 | +class ModuleSplitter { |
| 262 | +private: |
| 263 | + std::unique_ptr<Module> M; |
| 264 | + EntryPointGroupVec Groups; |
| 265 | + DependencyGraph DG; |
| 266 | + |
| 267 | +private: |
| 268 | + EntryPointGroup drawEntryPointGroup() { |
| 269 | + assert(Groups.size() > 0 && "Reached end of entry point groups list."); |
| 270 | + EntryPointGroup Group = std::move(Groups.back()); |
| 271 | + Groups.pop_back(); |
| 272 | + return Group; |
| 273 | + } |
| 274 | + |
| 275 | +public: |
| 276 | + ModuleSplitter(std::unique_ptr<Module> Module, EntryPointGroupVec &&GroupVec) |
| 277 | + : M(std::move(Module)), Groups(std::move(GroupVec)), DG(*M) { |
| 278 | + assert(!Groups.empty() && "Entry points groups collection is empty!"); |
| 279 | + } |
| 280 | + |
| 281 | + /// Gets next subsequence of entry points in an input module and provides |
| 282 | + /// split submodule containing these entry points and their dependencies. |
| 283 | + ModuleDesc getNextSplit() { |
| 284 | + return extractCallGraph(*M, drawEntryPointGroup(), DG); |
| 285 | + } |
| 286 | + |
| 287 | + /// Check that there are still submodules to split. |
| 288 | + bool hasMoreSplits() const { return Groups.size() > 0; } |
| 289 | +}; |
| 290 | + |
| 291 | +EntryPointGroupVec selectEntryPointGroups( |
| 292 | + const Module &M, function_ref<std::optional<int>(const Function &F)> EPC) { |
| 293 | + // std::map is used here to ensure stable ordering of entry point groups, |
| 294 | + // which is based on their contents, this greatly helps LIT tests |
| 295 | + // Note: EPC is allowed to return big identifiers. Therefore, we use |
| 296 | + // std::map + SmallVector approach here. |
| 297 | + std::map<int, EntryPointSet> EntryPointsMap; |
| 298 | + |
| 299 | + for (const auto &F : M.functions()) |
| 300 | + if (std::optional<int> Category = EPC(F); Category) |
| 301 | + EntryPointsMap[*Category].insert(&F); |
| 302 | + |
| 303 | + EntryPointGroupVec Groups; |
| 304 | + Groups.reserve(EntryPointsMap.size()); |
| 305 | + for (auto &[Key, EntryPoints] : EntryPointsMap) |
| 306 | + Groups.emplace_back(Key, std::move(EntryPoints)); |
| 307 | + |
| 308 | + return Groups; |
| 309 | +} |
| 310 | + |
| 311 | +} // namespace |
| 312 | + |
| 313 | +void llvm::splitModuleTransitiveFromEntryPoints( |
| 314 | + std::unique_ptr<Module> M, |
| 315 | + function_ref<std::optional<int>(const Function &F)> EntryPointCategorizer, |
| 316 | + function_ref<void(std::unique_ptr<Module> Part)> Callback) { |
| 317 | + EntryPointGroupVec Groups = selectEntryPointGroups(*M, EntryPointCategorizer); |
| 318 | + ModuleSplitter Splitter(std::move(M), std::move(Groups)); |
| 319 | + while (Splitter.hasMoreSplits()) { |
| 320 | + ModuleDesc MD = Splitter.getNextSplit(); |
| 321 | + Callback(std::move(MD.releaseModule())); |
| 322 | + } |
| 323 | +} |
0 commit comments