4646#include " llvm/Support/InitLLVM.h"
4747#include " llvm/Support/SourceMgr.h"
4848
49+ #include < atomic>
50+ #include < cassert>
4951#include < chrono>
5052#include < cstdlib>
53+ #include < future>
54+ #include < mutex>
5155#include < thread>
5256
5357// Utilities to allocate buffers
@@ -129,6 +133,11 @@ static llvm::cl::opt<std::string> benchmarkConfig(
129133 " Run benchmark with specific perf config only (skip tuning)" ),
130134 llvm::cl::value_desc(" perf config string" ), llvm::cl::init(" " ));
131135
136+ static llvm::cl::opt<unsigned > numCompileThreads (
137+ " num-compile-threads" ,
138+ llvm::cl::desc (" Number of parallel compilation threads (0 = auto)" ),
139+ llvm::cl::value_desc(" thread count" ), llvm::cl::init(0 ));
140+
132141// Ripped out of JitRunner.cpp
133142static OwningOpRef<ModuleOp> parseMLIRInput (StringRef inputFilename,
134143 MLIRContext *context) {
@@ -255,6 +264,20 @@ struct BenchmarkParams {
255264 bool showStats;
256265};
257266
267+ enum class CompilationStatus {
268+ NotApplicable, // Config not applicable for this kernel
269+ CompilationFailed, // Config applicable but compilation failed
270+ Success // Successfully compiled
271+ };
272+
273+ struct CompilationResult {
274+ SmallString<64 > perfConfig;
275+ CompilationStatus status = CompilationStatus::NotApplicable;
276+ SmallVector<std::string> hipModules;
277+ SmallVector<uint32_t > blockSizes;
278+ SmallVector<uint32_t > gridSizes;
279+ };
280+
258281// In order to match rocprof, returns time in nanoseconds
259282static FailureOr<double >
260283benchmarkKernels (ArrayRef<std::string> binaries,
@@ -449,22 +472,16 @@ static LogicalResult runTuningLoop(ModuleOp source) {
449472 bufferLengths.push_back (sizeInBits / 8 );
450473 }
451474
452- // 2. Set up pipelines. Do this only once to save on construction cost.
453- MLIRContext *ctx = source->getContext ();
454- PassManager applicability (source->getName (), PassManager::Nesting::Implicit);
455- PassManager compilation (source->getName (), PassManager::Nesting::Implicit);
456-
475+ // 2. Set up compilation options (shared across all threads)
457476 rock::KernelOptions applicabilityOpts;
458477 applicabilityOpts.enableApplicability = true ;
459478 applicabilityOpts.enableFusion = true ;
460479 applicabilityOpts.tuningFallback = false ;
461- rock::buildKernelPipeline (applicability, applicabilityOpts);
462480
463481 rock::KernelOptions compilationKernOpts;
464482 compilationKernOpts.enableApplicability = false ;
465483 compilationKernOpts.enableFusion = true ;
466484 compilationKernOpts.tuningFallback = false ;
467- rock::buildKernelPipeline (compilation, compilationKernOpts);
468485
469486 RocmDeviceName deviceName;
470487 StringRef archName =
@@ -478,12 +495,6 @@ static LogicalResult runTuningLoop(ModuleOp source) {
478495 backendOpts.features = backendFeatures;
479496 backendOpts.optLevel = 3 ;
480497 backendOpts.suppressDiagnostic = true ;
481- rock::buildBackendPipeline (compilation, backendOpts);
482-
483- // Now that we're in the kernel execution zone, turn off error messages
484- // Register a handler that swallows all diagnostic print
485- DiagnosticEngine &engine = ctx->getDiagEngine ();
486- engine.registerHandler ([](Diagnostic &diag) {});
487498
488499 // 3. Initialize host buffers and allocate device buffers
489500 std::vector<void *> hostBuffers;
@@ -498,20 +509,7 @@ static LogicalResult runTuningLoop(ModuleOp source) {
498509 gpuBuffers.push_back (gpuBuffer);
499510 }
500511
501- auto copyIR = [&](ModuleOp source,
502- StringAttr perfConfigAttr) -> OwningOpRef<ModuleOp> {
503- OwningOpRef<ModuleOp> copy = cast<ModuleOp>(source->clone ());
504-
505- copy->walk ([&perfConfigAttr](rock::RockGemmWrapperInterface op) {
506- op->setAttr (" perf_config" , perfConfigAttr);
507- });
508- copy->walk ([&perfConfigAttr](rock::RockGemmGemmWrapperInterface op) {
509- op->setAttr (" perf_config" , perfConfigAttr);
510- });
511- return copy;
512- };
513-
514- // 4. Actually tune
512+ // 4. Collect perf configs to compile
515513 std::vector<SmallString<64 >> configs;
516514 if (!benchmarkConfig.empty ()) {
517515 // Benchmark mode - just one config
@@ -540,65 +538,188 @@ static LogicalResult runTuningLoop(ModuleOp source) {
540538 useMedian, trimPercent,
541539 sleepUs, showStats};
542540
543- for (const auto &perfConfig : configs) {
544- llvm::outs () << perfConfig << " \t " ;
545- OwningOpRef<ModuleOp> tuneCopy = cast<ModuleOp>(source->clone ());
546- StringAttr perfConfigAttr = StringAttr::get (ctx, perfConfig);
547-
548- OwningOpRef<ModuleOp> applicabilityCopy = copyIR (source, perfConfigAttr);
549- if (!rock::isModuleFusible (applicabilityCopy.get (), perfConfig)) {
550- llvm::outs () << " N/A\n " ;
551- continue ;
541+ // Determine number of parallel threads
542+ unsigned numThreads = (numCompileThreads > 0 )
543+ ? numCompileThreads
544+ : std::thread::hardware_concurrency ();
545+ if (numThreads == 0 )
546+ numThreads = 4 ; // fallback
547+
548+ // Don't create more threads than configs to compile
549+ numThreads = std::min (numThreads, static_cast <unsigned >(configs.size ()));
550+
551+ // Serialize source module once (shared by all threads for cloning)
552+ std::string sourceModuleStr;
553+ llvm::raw_string_ostream sourceOs (sourceModuleStr);
554+ source->print (sourceOs);
555+ sourceOs.flush ();
556+
557+ // Parallel compilation phase
558+ std::vector<CompilationResult> compilationResults (configs.size ());
559+ std::mutex outputMutex; // For thread-safe console output
560+ std::atomic<bool > compilationFailed{
561+ false }; // Flag to signal early termination
562+
563+ auto compileConfig = [&](size_t idx) -> CompilationResult {
564+ CompilationResult result;
565+ result.perfConfig = configs[idx];
566+
567+ // Each thread needs its own context and pass managers for thread-safety
568+ DialectRegistry threadRegistry;
569+ registerRocMLIRDialects (threadRegistry);
570+ MLIRContext threadCtx (threadRegistry);
571+ threadCtx.getDiagEngine ().registerHandler ([](Diagnostic &diag) {});
572+
573+ // Parse the serialized module in this thread's context
574+ OwningOpRef<ModuleOp> threadSource =
575+ parseSourceString<ModuleOp>(sourceModuleStr, &threadCtx);
576+ if (!threadSource)
577+ return result;
578+
579+ // Set up pipelines for this thread
580+ PassManager threadApplicability (&threadCtx,
581+ PassManager::getAnyOpAnchorName (),
582+ PassManager::Nesting::Implicit);
583+ PassManager threadCompilation (&threadCtx, PassManager::getAnyOpAnchorName (),
584+ PassManager::Nesting::Implicit);
585+
586+ rock::buildKernelPipeline (threadApplicability, applicabilityOpts);
587+ rock::buildKernelPipeline (threadCompilation, compilationKernOpts);
588+ rock::buildBackendPipeline (threadCompilation, backendOpts);
589+
590+ StringAttr perfConfigAttr = StringAttr::get (&threadCtx, result.perfConfig );
591+
592+ // Helper to copy IR with perf config set
593+ auto copyIRThread = [&](ModuleOp src,
594+ StringAttr attr) -> OwningOpRef<ModuleOp> {
595+ OwningOpRef<ModuleOp> copy = cast<ModuleOp>(src->clone ());
596+ copy->walk ([&attr](rock::RockGemmWrapperInterface op) {
597+ op->setAttr (" perf_config" , attr);
598+ });
599+ copy->walk ([&attr](rock::RockGemmGemmWrapperInterface op) {
600+ op->setAttr (" perf_config" , attr);
601+ });
602+ return copy;
603+ };
604+
605+ // Applicability check
606+ OwningOpRef<ModuleOp> applicabilityCopy =
607+ copyIRThread (threadSource.get (), perfConfigAttr);
608+ if (!rock::isModuleFusible (applicabilityCopy.get (), result.perfConfig )) {
609+ result.status = CompilationStatus::NotApplicable;
610+ return result;
552611 }
553612
554- if (failed (applicability .run (applicabilityCopy.get ()))) {
555- llvm::outs () << " N/A \n " ;
556- continue ;
613+ if (failed (threadApplicability .run (applicabilityCopy.get ()))) {
614+ result. status = CompilationStatus::NotApplicable ;
615+ return result ;
557616 }
558617
559- // We have to get these now, they disappear later. Also, if these attributes
560- // aren't set the contract of the applicability pipeline changed and that's
561- // a problem.
562- SmallVector<uint32_t > blockSizes;
563- SmallVector<uint32_t > gridSizes;
618+ // Extract block and grid sizes
564619 for (auto &fnName : kernelFuncNames) {
565620 auto tunedFunc = applicabilityCopy->lookupSymbol <func::FuncOp>(fnName);
566621 if (!tunedFunc) {
567- llvm::errs () << " Tuned copy somehow missing kernel function\n " ;
568- return failure ();
622+ result.status = CompilationStatus::CompilationFailed;
623+ compilationFailed.store (true , std::memory_order_relaxed);
624+ return result;
569625 }
570- blockSizes.push_back (
626+ result. blockSizes .push_back (
571627 tunedFunc->getAttrOfType <IntegerAttr>(" block_size" ).getInt ());
572- gridSizes.push_back (
628+ result. gridSizes .push_back (
573629 tunedFunc->getAttrOfType <IntegerAttr>(" grid_size" ).getInt ());
574630 }
575631
576- OwningOpRef<ModuleOp> compileCopy = copyIR (source, perfConfigAttr);
577-
578- // NOTE: Call to run() resets the cl opts
579- if (failed (compilation.run (compileCopy.get ()))) {
580- llvm::errs () << " Backend pipeline failed for config: " << perfConfig
581- << " \n " ;
582- return failure ();
632+ // Compilation
633+ OwningOpRef<ModuleOp> compileCopy =
634+ copyIRThread (threadSource.get (), perfConfigAttr);
635+ if (failed (threadCompilation.run (compileCopy.get ()))) {
636+ std::lock_guard<std::mutex> lock (outputMutex);
637+ llvm::errs () << " Backend pipeline failed for config: "
638+ << result.perfConfig << " \n " ;
639+ result.status = CompilationStatus::CompilationFailed;
640+ compilationFailed.store (true , std::memory_order_relaxed);
641+ return result;
583642 }
584643
585- // Extract binary and benchmark
586- SmallVector<std::string> hipModules;
644+ // Extract binaries
587645 for (const auto &fnName : kernelFuncNames) {
588646 auto binary =
589647 compileCopy->lookupSymbol <gpu::BinaryOp>(fnName + " _module" );
590648 if (!binary) {
591- llvm::errs () << " could not find the GPU binary\n " ;
649+ result.status = CompilationStatus::CompilationFailed;
650+ compilationFailed.store (true , std::memory_order_relaxed);
651+ return result;
592652 }
593- hipModules.push_back (cast<gpu::ObjectAttr>(binary.getObjects ()[0 ])
594- .getObject ()
595- .getValue ()
596- .str ());
653+ result. hipModules .push_back (cast<gpu::ObjectAttr>(binary.getObjects ()[0 ])
654+ .getObject ()
655+ .getValue ()
656+ .str ());
597657 }
598658
659+ result.status = CompilationStatus::Success;
660+ return result;
661+ };
662+
663+ // Launch parallel compilation tasks with dynamic work stealing
664+ // Note: We use atomic counter instead of static partitioning because
665+ // compilation times vary dramatically between configs (NotApplicable is fast,
666+ // full compilation is slow). Dynamic work stealing provides better load
667+ // balancing by allowing fast threads to pick up more work.
668+ {
669+ std::atomic<size_t > nextIdx{0 };
670+
671+ // Thread pool with work stealing pattern
672+ auto worker = [&]() {
673+ while (true ) {
674+ // Check if any compilation has failed (relaxed: just an optimization
675+ // hint)
676+ if (compilationFailed.load (std::memory_order_relaxed))
677+ break ;
678+
679+ size_t idx = nextIdx.fetch_add (1 , std::memory_order_relaxed);
680+ if (idx >= configs.size ())
681+ break ;
682+
683+ compilationResults[idx] = compileConfig (idx);
684+ }
685+ };
686+
687+ std::vector<std::thread> threads;
688+ for (unsigned i = 0 ; i < numThreads; ++i) {
689+ threads.emplace_back (worker);
690+ }
691+
692+ for (auto &t : threads) {
693+ t.join ();
694+ }
695+ }
696+
697+ // Check if any compilation failed and terminate early
698+ if (compilationFailed.load (std::memory_order_relaxed)) {
699+ llvm::errs ()
700+ << " Compilation failed for one or more configs. Terminating.\n " ;
701+ return failure ();
702+ }
703+
704+ // Sequential benchmarking phase (must be sequential for accurate timing)
705+ // Note: Due to early exit on compilation failures, only NotApplicable and
706+ // Success statuses are possible here.
707+ for (const auto &result : compilationResults) {
708+ llvm::outs () << result.perfConfig << " \t " ;
709+
710+ if (result.status == CompilationStatus::NotApplicable) {
711+ llvm::outs () << " N/A\n " ;
712+ continue ;
713+ }
714+
715+ // At this point, status must be Success (we exited early on any failures)
716+ assert (result.status == CompilationStatus::Success &&
717+ " Unexpected compilation status in benchmarking phase" );
718+
599719 FailureOr<double > timing = benchmarkKernels (
600- hipModules, kernelFuncNames, blockSizes, gridSizes, dataType,
601- hostBuffers, gpuBuffers, bufferLengths, benchmarkParams);
720+ result.hipModules , kernelFuncNames, result.blockSizes , result.gridSizes ,
721+ dataType, hostBuffers, gpuBuffers, bufferLengths, benchmarkParams);
722+
602723 if (failed (timing)) {
603724 llvm::errs () << " Kernel execution failed\n " ;
604725 return failure ();
0 commit comments