Skip to content

Commit d9dadfd

Browse files
authored
Refactor ModuleToObject to offer more flexibility to subclass (NFC)
Some specific implementation of the offload may want more customization, and even avoid using LLVM in-tree to dispatch the ISA translation to a custom solution. This refactoring makes it possible for such implementation to work without even configuring the target backend in LLVM. Reviewers: fabianmcg Reviewed By: fabianmcg Pull Request: #71165
1 parent fcc26ba commit d9dadfd

File tree

9 files changed

+109
-91
lines changed

9 files changed

+109
-91
lines changed

mlir/include/mlir/Target/LLVM/ModuleToObject.h

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class ModuleToObject {
3131
public:
3232
ModuleToObject(Operation &module, StringRef triple, StringRef chip,
3333
StringRef features = {}, int optLevel = 3);
34-
virtual ~ModuleToObject() = default;
34+
virtual ~ModuleToObject();
3535

3636
/// Returns the operation being serialized.
3737
Operation &getOperation();
@@ -42,44 +42,43 @@ class ModuleToObject {
4242
protected:
4343
// Hooks to be implemented by derived classes.
4444

45+
/// Hook for computing the Datalayout
46+
virtual void setDataLayoutAndTriple(llvm::Module &module);
47+
4548
/// Hook for loading bitcode files, returns std::nullopt on failure.
4649
virtual std::optional<SmallVector<std::unique_ptr<llvm::Module>>>
47-
loadBitcodeFiles(llvm::Module &module, llvm::TargetMachine &targetMachine) {
50+
loadBitcodeFiles(llvm::Module &module) {
4851
return SmallVector<std::unique_ptr<llvm::Module>>();
4952
}
5053

5154
/// Hook for performing additional actions on a loaded bitcode file.
52-
virtual LogicalResult handleBitcodeFile(llvm::Module &module,
53-
llvm::TargetMachine &targetMachine) {
55+
virtual LogicalResult handleBitcodeFile(llvm::Module &module) {
5456
return success();
5557
}
5658

5759
/// Hook for performing additional actions on the llvmModule pre linking.
58-
virtual void handleModulePreLink(llvm::Module &module,
59-
llvm::TargetMachine &targetMachine) {}
60+
virtual void handleModulePreLink(llvm::Module &module) {}
6061

6162
/// Hook for performing additional actions on the llvmModule post linking.
62-
virtual void handleModulePostLink(llvm::Module &module,
63-
llvm::TargetMachine &targetMachine) {}
63+
virtual void handleModulePostLink(llvm::Module &module) {}
6464

6565
/// Serializes the LLVM IR bitcode to an object file, by default it serializes
6666
/// to LLVM bitcode.
6767
virtual std::optional<SmallVector<char, 0>>
68-
moduleToObject(llvm::Module &llvmModule, llvm::TargetMachine &targetMachine);
68+
moduleToObject(llvm::Module &llvmModule);
6969

7070
protected:
7171
/// Create the target machine based on the target triple and chip.
72-
std::unique_ptr<llvm::TargetMachine> createTargetMachine();
72+
/// This can fail if the target is not available.
73+
std::optional<llvm::TargetMachine *> getOrCreateTargetMachine();
7374

7475
/// Loads a bitcode file from path.
75-
std::unique_ptr<llvm::Module>
76-
loadBitcodeFile(llvm::LLVMContext &context,
77-
llvm::TargetMachine &targetMachine, StringRef path);
76+
std::unique_ptr<llvm::Module> loadBitcodeFile(llvm::LLVMContext &context,
77+
StringRef path);
7878

7979
/// Loads multiple bitcode files.
8080
LogicalResult loadBitcodeFilesFromList(
81-
llvm::LLVMContext &context, llvm::TargetMachine &targetMachine,
82-
ArrayRef<std::string> fileList,
81+
llvm::LLVMContext &context, ArrayRef<std::string> fileList,
8382
SmallVector<std::unique_ptr<llvm::Module>> &llvmModules,
8483
bool failureOnError = true);
8584

@@ -92,8 +91,7 @@ class ModuleToObject {
9291
SmallVector<std::unique_ptr<llvm::Module>> &&libs);
9392

9493
/// Optimize the module.
95-
LogicalResult optimizeModule(llvm::Module &module,
96-
llvm::TargetMachine &targetMachine, int optL);
94+
virtual LogicalResult optimizeModule(llvm::Module &module, int optL);
9795

9896
/// Utility function for translating to ISA, returns `std::nullopt` on
9997
/// failure.
@@ -115,6 +113,11 @@ class ModuleToObject {
115113

116114
/// Optimization level.
117115
int optLevel;
116+
117+
private:
118+
/// The TargetMachine created for the given Triple, if available.
119+
/// Accessible through `getOrCreateTargetMachine()`.
120+
std::unique_ptr<llvm::TargetMachine> targetMachine;
118121
};
119122
} // namespace LLVM
120123
} // namespace mlir

mlir/include/mlir/Target/LLVM/NVVM/Utils.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject {
5555

5656
/// Loads the bitcode files in `fileList`.
5757
virtual std::optional<SmallVector<std::unique_ptr<llvm::Module>>>
58-
loadBitcodeFiles(llvm::Module &module,
59-
llvm::TargetMachine &targetMachine) override;
58+
loadBitcodeFiles(llvm::Module &module) override;
6059

6160
protected:
6261
/// NVVM target attribute.

mlir/include/mlir/Target/LLVM/ROCDL/Utils.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,13 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject {
5454

5555
/// Loads the bitcode files in `fileList`.
5656
virtual std::optional<SmallVector<std::unique_ptr<llvm::Module>>>
57-
loadBitcodeFiles(llvm::Module &module,
58-
llvm::TargetMachine &targetMachine) override;
57+
loadBitcodeFiles(llvm::Module &module) override;
5958

6059
/// Adds `oclc` control variables to the LLVM module.
61-
void handleModulePreLink(llvm::Module &module,
62-
llvm::TargetMachine &targetMachine) override;
60+
void handleModulePreLink(llvm::Module &module) override;
6361

6462
/// Removes unnecessary metadata from the loaded bitcode files.
65-
LogicalResult handleBitcodeFile(llvm::Module &module,
66-
llvm::TargetMachine &targetMachine) override;
63+
LogicalResult handleBitcodeFile(llvm::Module &module) override;
6764

6865
protected:
6966
/// Appends the paths of common ROCm device libraries to `libs`.

mlir/lib/Conversion/GPUCommon/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
if (MLIR_ENABLE_CUDA_CONVERSIONS)
1+
if ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)
22
set(NVPTX_LIBS
33
NVPTXCodeGen
44
NVPTXDesc

mlir/lib/Dialect/GPU/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
if (MLIR_ENABLE_CUDA_CONVERSIONS)
1+
if ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)
22
set(NVPTX_LIBS
33
NVPTXCodeGen
44
NVPTXDesc

mlir/lib/Target/LLVM/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ add_mlir_library(MLIRTargetLLVM
2121
MLIRTargetLLVMIRExport
2222
)
2323

24-
if (MLIR_ENABLE_CUDA_CONVERSIONS)
24+
if ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)
2525
set(NVPTX_LIBS
2626
NVPTXCodeGen
2727
NVPTXDesc

mlir/lib/Target/LLVM/ModuleToObject.cpp

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -39,32 +39,34 @@ ModuleToObject::ModuleToObject(Operation &module, StringRef triple,
3939
: module(module), triple(triple), chip(chip), features(features),
4040
optLevel(optLevel) {}
4141

42+
ModuleToObject::~ModuleToObject() = default;
43+
4244
Operation &ModuleToObject::getOperation() { return module; }
4345

44-
std::unique_ptr<llvm::TargetMachine> ModuleToObject::createTargetMachine() {
45-
std::string error;
46+
std::optional<llvm::TargetMachine *>
47+
ModuleToObject::getOrCreateTargetMachine() {
48+
if (targetMachine)
49+
return targetMachine.get();
4650
// Load the target.
51+
std::string error;
4752
const llvm::Target *target =
4853
llvm::TargetRegistry::lookupTarget(triple, error);
4954
if (!target) {
50-
getOperation().emitError() << "Failed to lookup target: " << error;
51-
return {};
55+
getOperation().emitError()
56+
<< "Failed to lookup target for triple '" << triple << "' " << error;
57+
return std::nullopt;
5258
}
5359

5460
// Create the target machine using the target.
55-
llvm::TargetMachine *machine =
56-
target->createTargetMachine(triple, chip, features, {}, {});
57-
if (!machine) {
58-
getOperation().emitError() << "Failed to create the target machine.";
59-
return {};
60-
}
61-
return std::unique_ptr<llvm::TargetMachine>{machine};
61+
targetMachine.reset(
62+
target->createTargetMachine(triple, chip, features, {}, {}));
63+
if (!targetMachine)
64+
return std::nullopt;
65+
return targetMachine.get();
6266
}
6367

6468
std::unique_ptr<llvm::Module>
65-
ModuleToObject::loadBitcodeFile(llvm::LLVMContext &context,
66-
llvm::TargetMachine &targetMachine,
67-
StringRef path) {
69+
ModuleToObject::loadBitcodeFile(llvm::LLVMContext &context, StringRef path) {
6870
llvm::SMDiagnostic error;
6971
std::unique_ptr<llvm::Module> library =
7072
llvm::getLazyIRFileModule(path, error, context);
@@ -73,15 +75,14 @@ ModuleToObject::loadBitcodeFile(llvm::LLVMContext &context,
7375
<< ", error: " << error.getMessage();
7476
return nullptr;
7577
}
76-
if (failed(handleBitcodeFile(*library, targetMachine))) {
78+
if (failed(handleBitcodeFile(*library))) {
7779
return nullptr;
7880
}
7981
return library;
8082
}
8183

8284
LogicalResult ModuleToObject::loadBitcodeFilesFromList(
83-
llvm::LLVMContext &context, llvm::TargetMachine &targetMachine,
84-
ArrayRef<std::string> fileList,
85+
llvm::LLVMContext &context, ArrayRef<std::string> fileList,
8586
SmallVector<std::unique_ptr<llvm::Module>> &llvmModules,
8687
bool failureOnError) {
8788
for (const std::string &str : fileList) {
@@ -93,7 +94,7 @@ LogicalResult ModuleToObject::loadBitcodeFilesFromList(
9394
return failure();
9495
}
9596
// Load the file or abort on error.
96-
if (auto bcFile = loadBitcodeFile(context, targetMachine, pathRef))
97+
if (auto bcFile = loadBitcodeFile(context, pathRef))
9798
llvmModules.push_back(std::move(bcFile));
9899
else if (failureOnError)
99100
return failure();
@@ -137,16 +138,22 @@ ModuleToObject::linkFiles(llvm::Module &module,
137138
}
138139

139140
LogicalResult ModuleToObject::optimizeModule(llvm::Module &module,
140-
llvm::TargetMachine &targetMachine,
141+
141142
int optLevel) {
142143
if (optLevel < 0 || optLevel > 3)
143144
return getOperation().emitError()
144145
<< "Invalid optimization level: " << optLevel << ".";
145146

146-
targetMachine.setOptLevel(static_cast<llvm::CodeGenOptLevel>(optLevel));
147+
std::optional<llvm::TargetMachine *> targetMachine =
148+
getOrCreateTargetMachine();
149+
if (!targetMachine)
150+
return getOperation().emitError()
151+
<< "Target Machine unavailable for triple " << triple
152+
<< ", can't optimize with LLVM\n";
153+
(*targetMachine)->setOptLevel(static_cast<llvm::CodeGenOptLevel>(optLevel));
147154

148155
auto transformer =
149-
makeOptimizingTransformer(optLevel, /*sizeLevel=*/0, &targetMachine);
156+
makeOptimizingTransformer(optLevel, /*sizeLevel=*/0, *targetMachine);
150157
auto error = transformer(&module);
151158
if (error) {
152159
InFlightDiagnostic mlirError = getOperation().emitError();
@@ -178,9 +185,19 @@ ModuleToObject::translateToISA(llvm::Module &llvmModule,
178185
return stream.str();
179186
}
180187

188+
void ModuleToObject::setDataLayoutAndTriple(llvm::Module &module) {
189+
// Create the target machine.
190+
std::optional<llvm::TargetMachine *> targetMachine =
191+
getOrCreateTargetMachine();
192+
if (targetMachine) {
193+
// Set the data layout and target triple of the module.
194+
module.setDataLayout((*targetMachine)->createDataLayout());
195+
module.setTargetTriple((*targetMachine)->getTargetTriple().getTriple());
196+
}
197+
}
198+
181199
std::optional<SmallVector<char, 0>>
182-
ModuleToObject::moduleToObject(llvm::Module &llvmModule,
183-
llvm::TargetMachine &targetMachine) {
200+
ModuleToObject::moduleToObject(llvm::Module &llvmModule) {
184201
SmallVector<char, 0> binaryData;
185202
// Write the LLVM module bitcode to a buffer.
186203
llvm::raw_svector_ostream outputStream(binaryData);
@@ -196,32 +213,24 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
196213
getOperation().emitError() << "Failed creating the llvm::Module.";
197214
return std::nullopt;
198215
}
199-
200-
// Create the target machine.
201-
std::unique_ptr<llvm::TargetMachine> targetMachine = createTargetMachine();
202-
if (!targetMachine)
203-
return std::nullopt;
204-
205-
// Set the data layout and target triple of the module.
206-
llvmModule->setDataLayout(targetMachine->createDataLayout());
207-
llvmModule->setTargetTriple(targetMachine->getTargetTriple().getTriple());
216+
setDataLayoutAndTriple(*llvmModule);
208217

209218
// Link bitcode files.
210-
handleModulePreLink(*llvmModule, *targetMachine);
219+
handleModulePreLink(*llvmModule);
211220
{
212-
auto libs = loadBitcodeFiles(*llvmModule, *targetMachine);
221+
auto libs = loadBitcodeFiles(*llvmModule);
213222
if (!libs)
214223
return std::nullopt;
215224
if (!libs->empty())
216225
if (failed(linkFiles(*llvmModule, std::move(*libs))))
217226
return std::nullopt;
218-
handleModulePostLink(*llvmModule, *targetMachine);
227+
handleModulePostLink(*llvmModule);
219228
}
220229

221230
// Optimize the module.
222-
if (failed(optimizeModule(*llvmModule, *targetMachine, optLevel)))
231+
if (failed(optimizeModule(*llvmModule, optLevel)))
223232
return std::nullopt;
224233

225234
// Return the serialized object.
226-
return moduleToObject(*llvmModule, *targetMachine);
235+
return moduleToObject(*llvmModule);
227236
}

mlir/lib/Target/LLVM/NVVM/Target.cpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ void SerializeGPUModuleBase::init() {
106106
static llvm::once_flag initializeBackendOnce;
107107
llvm::call_once(initializeBackendOnce, []() {
108108
// If the `NVPTX` LLVM target was built, initialize it.
109-
#if MLIR_CUDA_CONVERSIONS_ENABLED == 1
109+
#if LLVM_HAS_NVPTX_TARGET
110110
LLVMInitializeNVPTXTarget();
111111
LLVMInitializeNVPTXTargetInfo();
112112
LLVMInitializeNVPTXTargetMC();
@@ -148,11 +148,10 @@ LogicalResult SerializeGPUModuleBase::appendStandardLibs() {
148148
}
149149

150150
std::optional<SmallVector<std::unique_ptr<llvm::Module>>>
151-
SerializeGPUModuleBase::loadBitcodeFiles(llvm::Module &module,
152-
llvm::TargetMachine &targetMachine) {
151+
SerializeGPUModuleBase::loadBitcodeFiles(llvm::Module &module) {
153152
SmallVector<std::unique_ptr<llvm::Module>> bcFiles;
154-
if (failed(loadBitcodeFilesFromList(module.getContext(), targetMachine,
155-
fileList, bcFiles, true)))
153+
if (failed(loadBitcodeFilesFromList(module.getContext(), fileList, bcFiles,
154+
true)))
156155
return std::nullopt;
157156
return std::move(bcFiles);
158157
}
@@ -175,8 +174,7 @@ class NVPTXSerializer : public SerializeGPUModuleBase {
175174
compileToBinaryNVPTX(const std::string &ptxCode);
176175

177176
std::optional<SmallVector<char, 0>>
178-
moduleToObject(llvm::Module &llvmModule,
179-
llvm::TargetMachine &targetMachine) override;
177+
moduleToObject(llvm::Module &llvmModule) override;
180178

181179
private:
182180
using TmpFile = std::pair<llvm::SmallString<128>, llvm::FileRemover>;
@@ -514,8 +512,7 @@ NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
514512
#endif // MLIR_NVPTXCOMPILER_ENABLED == 1
515513

516514
std::optional<SmallVector<char, 0>>
517-
NVPTXSerializer::moduleToObject(llvm::Module &llvmModule,
518-
llvm::TargetMachine &targetMachine) {
515+
NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
519516
// Return LLVM IR if the compilation target is offload.
520517
#define DEBUG_TYPE "serialize-to-llvm"
521518
LLVM_DEBUG({
@@ -526,11 +523,18 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule,
526523
});
527524
#undef DEBUG_TYPE
528525
if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Offload)
529-
return SerializeGPUModuleBase::moduleToObject(llvmModule, targetMachine);
526+
return SerializeGPUModuleBase::moduleToObject(llvmModule);
530527

531528
// Emit PTX code.
529+
std::optional<llvm::TargetMachine *> targetMachine =
530+
getOrCreateTargetMachine();
531+
if (!targetMachine) {
532+
getOperation().emitError() << "Target Machine unavailable for triple "
533+
<< triple << ", can't optimize with LLVM\n";
534+
return std::nullopt;
535+
}
532536
std::optional<std::string> serializedISA =
533-
translateToISA(llvmModule, targetMachine);
537+
translateToISA(llvmModule, **targetMachine);
534538
if (!serializedISA) {
535539
getOperation().emitError() << "Failed translating the module to ISA.";
536540
return std::nullopt;

0 commit comments

Comments
 (0)