@@ -25,6 +25,7 @@ limitations under the License.
2525#include " absl/base/nullability.h"
2626#include " absl/log/check.h"
2727#include " absl/log/log.h"
28+ #include " absl/status/status.h"
2829#include " absl/status/statusor.h"
2930#include " absl/strings/str_format.h"
3031#include " absl/strings/str_split.h"
@@ -47,7 +48,6 @@ limitations under the License.
4748#include " llvm/Passes/StandardInstrumentations.h"
4849#include " llvm/Support/Casting.h"
4950#include " llvm/Support/CodeGen.h"
50- #include " llvm/Support/Debug.h"
5151#include " llvm/Support/Errc.h"
5252#include " llvm/Support/Error.h"
5353#include " llvm/Support/MemoryBuffer.h"
@@ -240,6 +240,42 @@ llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> IrCompiler::operator()(
240240 }
241241 }
242242
243+ if (llvm::Error ir_passes_error =
244+ RunIrPasses (module , target_machine->get ())) {
245+ return ir_passes_error;
246+ }
247+
248+ VLOG (2 ) << " IR after optimizations" ;
249+ XLA_VLOG_LINES (2 , llvm_ir::DumpToString (&module ));
250+
251+ { // Synchronize access to user-defined hooks.
252+ absl::MutexLock lock (&mutex_);
253+ if (hooks_.post_optimization ) {
254+ hooks_.post_optimization (module );
255+ }
256+ }
257+
258+ std::unique_ptr<llvm::MemoryBuffer> mc_memory_buffer =
259+ EmitMachineCode (module , target_machine->get ());
260+
261+ { // Synchronize access to user-defined hooks.
262+ absl::MutexLock lock (&mutex_);
263+ if (hooks_.post_codegen ) {
264+ llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file =
265+ llvm::object::ObjectFile::createObjectFile (*mc_memory_buffer);
266+ if (obj_file) {
267+ hooks_.post_codegen (module , *obj_file.get ());
268+ } else {
269+ LOG (WARNING) << " Could not convert memory buffer to object file" ;
270+ }
271+ }
272+ }
273+
274+ return std::move (mc_memory_buffer);
275+ }
276+
277+ llvm::Error IrCompiler::RunIrPasses (llvm::Module& module ,
278+ llvm::TargetMachine* target_machine) const {
243279 llvm::PipelineTuningOptions pto = GetPipelineTuningOptions (module , options_);
244280 llvm::LoopAnalysisManager lam;
245281 llvm::FunctionAnalysisManager fam;
@@ -250,10 +286,10 @@ llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> IrCompiler::operator()(
250286 llvm::StandardInstrumentations si (module .getContext (), false );
251287 si.registerCallbacks (pic, &mam);
252288
253- llvm::PassBuilder pb (target_machine-> get () , pto, {}, &pic);
289+ llvm::PassBuilder pb (target_machine, pto, {}, &pic);
254290
255291 // Add the appropriate TargetLibraryInfo.
256- llvm::Triple target_triple ((* target_machine) ->getTargetTriple ());
292+ llvm::Triple target_triple (target_machine->getTargetTriple ());
257293 auto target_library_info_impl =
258294 std::make_unique<llvm::TargetLibraryInfoImpl>(target_triple);
259295 target_library_info_impl->addVectorizableFunctions (
@@ -281,51 +317,49 @@ llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> IrCompiler::operator()(
281317 pm.addPass (pb.buildPerModuleDefaultPipeline (opt_level));
282318 }
283319
284- CHECK (!llvm::verifyModule (module , &llvm::dbgs ()));
320+ {
321+ std::string error_string;
322+ llvm::raw_string_ostream error_stream (error_string);
323+ if (llvm::verifyModule (module , &error_stream)) {
324+ return llvm::make_error<llvm::StringError>(
325+ llvm::errc::invalid_argument,
326+ absl::StrFormat (" Invalid LLVM IR before optimizations:\n %s" ,
327+ error_stream.str ()));
328+ }
329+ }
285330
286331 pm.run (module , mam);
287332
288- CHECK (!llvm::verifyModule (module , &llvm::dbgs ()));
333+ {
334+ std::string error_string;
335+ llvm::raw_string_ostream error_stream (error_string);
336+ if (llvm::verifyModule (module , &error_stream)) {
337+ return llvm::make_error<llvm::StringError>(
338+ llvm::errc::invalid_argument,
339+ absl::StrFormat (" Invalid LLVM IR after optimizations:\n %s" ,
340+ error_stream.str ()));
341+ }
342+ }
289343
290344 RewriteToPolynomialApproximations (&module , options_.fast_math_flags );
291345
346+ return llvm::Error::success ();
347+ }
348+
349+ std::unique_ptr<llvm::MemoryBuffer> IrCompiler::EmitMachineCode (
350+ llvm::Module& module , llvm::TargetMachine* target_machine) const {
292351 // Buffer for holding machine code prior to constructing the ObjectFile.
293352 llvm::SmallVector<char , 0 > mc_stream_buffer;
294353 llvm::raw_svector_ostream ostream (mc_stream_buffer);
295354
296- VLOG (2 ) << " IR after optimizations" ;
297- XLA_VLOG_LINES (2 , llvm_ir::DumpToString (&module ));
298-
299- { // Synchronize access to user-defined hooks.
300- absl::MutexLock lock (&mutex_);
301- if (hooks_.post_optimization ) {
302- hooks_.post_optimization (module );
303- }
304- }
305-
306355 // Generate code.
307356 llvm::MCContext* mc_context;
308357 llvm::legacy::PassManager codegen_passes;
309- (* target_machine) ->addPassesToEmitMC (codegen_passes, mc_context, ostream);
358+ target_machine->addPassesToEmitMC (codegen_passes, mc_context, ostream);
310359 codegen_passes.run (module );
311360
312- std::unique_ptr<llvm::MemoryBuffer> mc_memory_buffer (
313- new llvm::SmallVectorMemoryBuffer (std::move (mc_stream_buffer)));
314-
315- { // Synchronize access to user-defined hooks.
316- absl::MutexLock lock (&mutex_);
317- if (hooks_.post_codegen ) {
318- llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file =
319- llvm::object::ObjectFile::createObjectFile (*mc_memory_buffer);
320- if (obj_file) {
321- hooks_.post_codegen (module , *obj_file.get ());
322- } else {
323- LOG (WARNING) << " Could not convert memory buffer to object file" ;
324- }
325- }
326- }
327-
328- return std::move (mc_memory_buffer);
361+ return std::make_unique<llvm::SmallVectorMemoryBuffer>(
362+ std::move (mc_stream_buffer));
329363}
330364
331365llvm::CodeGenOptLevel IrCompiler::GetCodeGenOptLevel (
0 commit comments