Skip to content

Commit 5e9ef96

Browse files
WillFroom authored and Google-ML-Automation committed
[XLA:CPU] Expose IrCompiler IR & MC passes as separate methods.
PiperOrigin-RevId: 739112289
1 parent f58aac6 commit 5e9ef96

File tree

3 files changed

+76
-34
lines changed

3 files changed

+76
-34
lines changed

xla/backends/cpu/codegen/BUILD

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,6 @@ cc_library(
6060
deps = [
6161
":cpu_features",
6262
":polynomial_approximations",
63-
"//xla:debug_options_flags",
6463
"//xla:util",
6564
"//xla/service:hlo_module_config",
6665
"//xla/service/cpu:cpu_options",
@@ -70,6 +69,7 @@ cc_library(
7069
"@com_google_absl//absl/base:nullability",
7170
"@com_google_absl//absl/log",
7271
"@com_google_absl//absl/log:check",
72+
"@com_google_absl//absl/status",
7373
"@com_google_absl//absl/status:statusor",
7474
"@com_google_absl//absl/strings",
7575
"@com_google_absl//absl/strings:str_format",

xla/backends/cpu/codegen/ir_compiler.cc

Lines changed: 67 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525
#include "absl/base/nullability.h"
2626
#include "absl/log/check.h"
2727
#include "absl/log/log.h"
28+
#include "absl/status/status.h"
2829
#include "absl/status/statusor.h"
2930
#include "absl/strings/str_format.h"
3031
#include "absl/strings/str_split.h"
@@ -47,7 +48,6 @@ limitations under the License.
4748
#include "llvm/Passes/StandardInstrumentations.h"
4849
#include "llvm/Support/Casting.h"
4950
#include "llvm/Support/CodeGen.h"
50-
#include "llvm/Support/Debug.h"
5151
#include "llvm/Support/Errc.h"
5252
#include "llvm/Support/Error.h"
5353
#include "llvm/Support/MemoryBuffer.h"
@@ -240,6 +240,42 @@ llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> IrCompiler::operator()(
240240
}
241241
}
242242

243+
if (llvm::Error ir_passes_error =
244+
RunIrPasses(module, target_machine->get())) {
245+
return ir_passes_error;
246+
}
247+
248+
VLOG(2) << "IR after optimizations";
249+
XLA_VLOG_LINES(2, llvm_ir::DumpToString(&module));
250+
251+
{ // Synchronize access to user-defined hooks.
252+
absl::MutexLock lock(&mutex_);
253+
if (hooks_.post_optimization) {
254+
hooks_.post_optimization(module);
255+
}
256+
}
257+
258+
std::unique_ptr<llvm::MemoryBuffer> mc_memory_buffer =
259+
EmitMachineCode(module, target_machine->get());
260+
261+
{ // Synchronize access to user-defined hooks.
262+
absl::MutexLock lock(&mutex_);
263+
if (hooks_.post_codegen) {
264+
llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file =
265+
llvm::object::ObjectFile::createObjectFile(*mc_memory_buffer);
266+
if (obj_file) {
267+
hooks_.post_codegen(module, *obj_file.get());
268+
} else {
269+
LOG(WARNING) << "Could not convert memory buffer to object file";
270+
}
271+
}
272+
}
273+
274+
return std::move(mc_memory_buffer);
275+
}
276+
277+
llvm::Error IrCompiler::RunIrPasses(llvm::Module& module,
278+
llvm::TargetMachine* target_machine) const {
243279
llvm::PipelineTuningOptions pto = GetPipelineTuningOptions(module, options_);
244280
llvm::LoopAnalysisManager lam;
245281
llvm::FunctionAnalysisManager fam;
@@ -250,10 +286,10 @@ llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> IrCompiler::operator()(
250286
llvm::StandardInstrumentations si(module.getContext(), false);
251287
si.registerCallbacks(pic, &mam);
252288

253-
llvm::PassBuilder pb(target_machine->get(), pto, {}, &pic);
289+
llvm::PassBuilder pb(target_machine, pto, {}, &pic);
254290

255291
// Add the appropriate TargetLibraryInfo.
256-
llvm::Triple target_triple((*target_machine)->getTargetTriple());
292+
llvm::Triple target_triple(target_machine->getTargetTriple());
257293
auto target_library_info_impl =
258294
std::make_unique<llvm::TargetLibraryInfoImpl>(target_triple);
259295
target_library_info_impl->addVectorizableFunctions(
@@ -281,51 +317,49 @@ llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> IrCompiler::operator()(
281317
pm.addPass(pb.buildPerModuleDefaultPipeline(opt_level));
282318
}
283319

284-
CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
320+
{
321+
std::string error_string;
322+
llvm::raw_string_ostream error_stream(error_string);
323+
if (llvm::verifyModule(module, &error_stream)) {
324+
return llvm::make_error<llvm::StringError>(
325+
llvm::errc::invalid_argument,
326+
absl::StrFormat("Invalid LLVM IR before optimizations:\n%s",
327+
error_stream.str()));
328+
}
329+
}
285330

286331
pm.run(module, mam);
287332

288-
CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
333+
{
334+
std::string error_string;
335+
llvm::raw_string_ostream error_stream(error_string);
336+
if (llvm::verifyModule(module, &error_stream)) {
337+
return llvm::make_error<llvm::StringError>(
338+
llvm::errc::invalid_argument,
339+
absl::StrFormat("Invalid LLVM IR after optimizations:\n%s",
340+
error_stream.str()));
341+
}
342+
}
289343

290344
RewriteToPolynomialApproximations(&module, options_.fast_math_flags);
291345

346+
return llvm::Error::success();
347+
}
348+
349+
std::unique_ptr<llvm::MemoryBuffer> IrCompiler::EmitMachineCode(
350+
llvm::Module& module, llvm::TargetMachine* target_machine) const {
292351
// Buffer for holding machine code prior to constructing the ObjectFile.
293352
llvm::SmallVector<char, 0> mc_stream_buffer;
294353
llvm::raw_svector_ostream ostream(mc_stream_buffer);
295354

296-
VLOG(2) << "IR after optimizations";
297-
XLA_VLOG_LINES(2, llvm_ir::DumpToString(&module));
298-
299-
{ // Synchronize access to user-defined hooks.
300-
absl::MutexLock lock(&mutex_);
301-
if (hooks_.post_optimization) {
302-
hooks_.post_optimization(module);
303-
}
304-
}
305-
306355
// Generate code.
307356
llvm::MCContext* mc_context;
308357
llvm::legacy::PassManager codegen_passes;
309-
(*target_machine)->addPassesToEmitMC(codegen_passes, mc_context, ostream);
358+
target_machine->addPassesToEmitMC(codegen_passes, mc_context, ostream);
310359
codegen_passes.run(module);
311360

312-
std::unique_ptr<llvm::MemoryBuffer> mc_memory_buffer(
313-
new llvm::SmallVectorMemoryBuffer(std::move(mc_stream_buffer)));
314-
315-
{ // Synchronize access to user-defined hooks.
316-
absl::MutexLock lock(&mutex_);
317-
if (hooks_.post_codegen) {
318-
llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file =
319-
llvm::object::ObjectFile::createObjectFile(*mc_memory_buffer);
320-
if (obj_file) {
321-
hooks_.post_codegen(module, *obj_file.get());
322-
} else {
323-
LOG(WARNING) << "Could not convert memory buffer to object file";
324-
}
325-
}
326-
}
327-
328-
return std::move(mc_memory_buffer);
361+
return std::make_unique<llvm::SmallVectorMemoryBuffer>(
362+
std::move(mc_stream_buffer));
329363
}
330364

331365
llvm::CodeGenOptLevel IrCompiler::GetCodeGenOptLevel(

xla/backends/cpu/codegen/ir_compiler.h

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -111,6 +111,14 @@ class IrCompiler : public llvm::orc::IRCompileLayer::IRCompiler {
111111
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> operator()(
112112
llvm::Module& module) final;
113113

114+
// Runs the IR passes on the given module.
115+
llvm::Error RunIrPasses(llvm::Module& module,
116+
llvm::TargetMachine* target_machine) const;
117+
118+
// Emits machine code for the given module.
119+
std::unique_ptr<llvm::MemoryBuffer> EmitMachineCode(
120+
llvm::Module& module, llvm::TargetMachine* target_machine) const;
121+
114122
static llvm::CodeGenOptLevel GetCodeGenOptLevel(
115123
const HloModuleConfig& module_config);
116124

0 commit comments

Comments
 (0)