Skip to content

Commit e4ab0c1

Browse files
authored
add multi threaded compilation for rocmlir-tuning-driver (#2071)
* add multi threaded compilation for rocmlir-tuning-driver
1 parent 1752450 commit e4ab0c1

File tree

1 file changed

+185
-64
lines changed

1 file changed

+185
-64
lines changed

mlir/tools/rocmlir-tuning-driver/rocmlir-tuning-driver.cpp

Lines changed: 185 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,12 @@
4646
#include "llvm/Support/InitLLVM.h"
4747
#include "llvm/Support/SourceMgr.h"
4848

49+
#include <atomic>
50+
#include <cassert>
4951
#include <chrono>
5052
#include <cstdlib>
53+
#include <future>
54+
#include <mutex>
5155
#include <thread>
5256

5357
// Utilities to allocate buffers
@@ -129,6 +133,11 @@ static llvm::cl::opt<std::string> benchmarkConfig(
129133
"Run benchmark with specific perf config only (skip tuning)"),
130134
llvm::cl::value_desc("perf config string"), llvm::cl::init(""));
131135

136+
static llvm::cl::opt<unsigned> numCompileThreads(
137+
"num-compile-threads",
138+
llvm::cl::desc("Number of parallel compilation threads (0 = auto)"),
139+
llvm::cl::value_desc("thread count"), llvm::cl::init(0));
140+
132141
// Ripped out of JitRunner.cpp
133142
static OwningOpRef<ModuleOp> parseMLIRInput(StringRef inputFilename,
134143
MLIRContext *context) {
@@ -255,6 +264,20 @@ struct BenchmarkParams {
255264
bool showStats;
256265
};
257266

267+
enum class CompilationStatus {
268+
NotApplicable, // Config not applicable for this kernel
269+
CompilationFailed, // Config applicable but compilation failed
270+
Success // Successfully compiled
271+
};
272+
273+
struct CompilationResult {
274+
SmallString<64> perfConfig;
275+
CompilationStatus status = CompilationStatus::NotApplicable;
276+
SmallVector<std::string> hipModules;
277+
SmallVector<uint32_t> blockSizes;
278+
SmallVector<uint32_t> gridSizes;
279+
};
280+
258281
// In order to match rocprof, returns time in nanoseconds
259282
static FailureOr<double>
260283
benchmarkKernels(ArrayRef<std::string> binaries,
@@ -449,22 +472,16 @@ static LogicalResult runTuningLoop(ModuleOp source) {
449472
bufferLengths.push_back(sizeInBits / 8);
450473
}
451474

452-
// 2. Set up pipelines. Do this only once to save on construction cost.
453-
MLIRContext *ctx = source->getContext();
454-
PassManager applicability(source->getName(), PassManager::Nesting::Implicit);
455-
PassManager compilation(source->getName(), PassManager::Nesting::Implicit);
456-
475+
// 2. Set up compilation options (shared across all threads)
457476
rock::KernelOptions applicabilityOpts;
458477
applicabilityOpts.enableApplicability = true;
459478
applicabilityOpts.enableFusion = true;
460479
applicabilityOpts.tuningFallback = false;
461-
rock::buildKernelPipeline(applicability, applicabilityOpts);
462480

463481
rock::KernelOptions compilationKernOpts;
464482
compilationKernOpts.enableApplicability = false;
465483
compilationKernOpts.enableFusion = true;
466484
compilationKernOpts.tuningFallback = false;
467-
rock::buildKernelPipeline(compilation, compilationKernOpts);
468485

469486
RocmDeviceName deviceName;
470487
StringRef archName =
@@ -478,12 +495,6 @@ static LogicalResult runTuningLoop(ModuleOp source) {
478495
backendOpts.features = backendFeatures;
479496
backendOpts.optLevel = 3;
480497
backendOpts.suppressDiagnostic = true;
481-
rock::buildBackendPipeline(compilation, backendOpts);
482-
483-
// Now that we're in the kernel execution zone, turn off error messages
484-
// Register a handler that swallows all diagnostic print
485-
DiagnosticEngine &engine = ctx->getDiagEngine();
486-
engine.registerHandler([](Diagnostic &diag) {});
487498

488499
// 3. Initialize host buffers and allocate device buffers
489500
std::vector<void *> hostBuffers;
@@ -498,20 +509,7 @@ static LogicalResult runTuningLoop(ModuleOp source) {
498509
gpuBuffers.push_back(gpuBuffer);
499510
}
500511

501-
auto copyIR = [&](ModuleOp source,
502-
StringAttr perfConfigAttr) -> OwningOpRef<ModuleOp> {
503-
OwningOpRef<ModuleOp> copy = cast<ModuleOp>(source->clone());
504-
505-
copy->walk([&perfConfigAttr](rock::RockGemmWrapperInterface op) {
506-
op->setAttr("perf_config", perfConfigAttr);
507-
});
508-
copy->walk([&perfConfigAttr](rock::RockGemmGemmWrapperInterface op) {
509-
op->setAttr("perf_config", perfConfigAttr);
510-
});
511-
return copy;
512-
};
513-
514-
// 4. Actually tune
512+
// 4. Collect perf configs to compile
515513
std::vector<SmallString<64>> configs;
516514
if (!benchmarkConfig.empty()) {
517515
// Benchmark mode - just one config
@@ -540,65 +538,188 @@ static LogicalResult runTuningLoop(ModuleOp source) {
540538
useMedian, trimPercent,
541539
sleepUs, showStats};
542540

543-
for (const auto &perfConfig : configs) {
544-
llvm::outs() << perfConfig << "\t";
545-
OwningOpRef<ModuleOp> tuneCopy = cast<ModuleOp>(source->clone());
546-
StringAttr perfConfigAttr = StringAttr::get(ctx, perfConfig);
547-
548-
OwningOpRef<ModuleOp> applicabilityCopy = copyIR(source, perfConfigAttr);
549-
if (!rock::isModuleFusible(applicabilityCopy.get(), perfConfig)) {
550-
llvm::outs() << "N/A\n";
551-
continue;
541+
// Determine number of parallel threads
542+
unsigned numThreads = (numCompileThreads > 0)
543+
? numCompileThreads
544+
: std::thread::hardware_concurrency();
545+
if (numThreads == 0)
546+
numThreads = 4; // fallback
547+
548+
// Don't create more threads than configs to compile
549+
numThreads = std::min(numThreads, static_cast<unsigned>(configs.size()));
550+
551+
// Serialize source module once (shared by all threads for cloning)
552+
std::string sourceModuleStr;
553+
llvm::raw_string_ostream sourceOs(sourceModuleStr);
554+
source->print(sourceOs);
555+
sourceOs.flush();
556+
557+
// Parallel compilation phase
558+
std::vector<CompilationResult> compilationResults(configs.size());
559+
std::mutex outputMutex; // For thread-safe console output
560+
std::atomic<bool> compilationFailed{
561+
false}; // Flag to signal early termination
562+
563+
auto compileConfig = [&](size_t idx) -> CompilationResult {
564+
CompilationResult result;
565+
result.perfConfig = configs[idx];
566+
567+
// Each thread needs its own context and pass managers for thread-safety
568+
DialectRegistry threadRegistry;
569+
registerRocMLIRDialects(threadRegistry);
570+
MLIRContext threadCtx(threadRegistry);
571+
threadCtx.getDiagEngine().registerHandler([](Diagnostic &diag) {});
572+
573+
// Parse the serialized module in this thread's context
574+
OwningOpRef<ModuleOp> threadSource =
575+
parseSourceString<ModuleOp>(sourceModuleStr, &threadCtx);
576+
if (!threadSource)
577+
return result;
578+
579+
// Set up pipelines for this thread
580+
PassManager threadApplicability(&threadCtx,
581+
PassManager::getAnyOpAnchorName(),
582+
PassManager::Nesting::Implicit);
583+
PassManager threadCompilation(&threadCtx, PassManager::getAnyOpAnchorName(),
584+
PassManager::Nesting::Implicit);
585+
586+
rock::buildKernelPipeline(threadApplicability, applicabilityOpts);
587+
rock::buildKernelPipeline(threadCompilation, compilationKernOpts);
588+
rock::buildBackendPipeline(threadCompilation, backendOpts);
589+
590+
StringAttr perfConfigAttr = StringAttr::get(&threadCtx, result.perfConfig);
591+
592+
// Helper to copy IR with perf config set
593+
auto copyIRThread = [&](ModuleOp src,
594+
StringAttr attr) -> OwningOpRef<ModuleOp> {
595+
OwningOpRef<ModuleOp> copy = cast<ModuleOp>(src->clone());
596+
copy->walk([&attr](rock::RockGemmWrapperInterface op) {
597+
op->setAttr("perf_config", attr);
598+
});
599+
copy->walk([&attr](rock::RockGemmGemmWrapperInterface op) {
600+
op->setAttr("perf_config", attr);
601+
});
602+
return copy;
603+
};
604+
605+
// Applicability check
606+
OwningOpRef<ModuleOp> applicabilityCopy =
607+
copyIRThread(threadSource.get(), perfConfigAttr);
608+
if (!rock::isModuleFusible(applicabilityCopy.get(), result.perfConfig)) {
609+
result.status = CompilationStatus::NotApplicable;
610+
return result;
552611
}
553612

554-
if (failed(applicability.run(applicabilityCopy.get()))) {
555-
llvm::outs() << "N/A\n";
556-
continue;
613+
if (failed(threadApplicability.run(applicabilityCopy.get()))) {
614+
result.status = CompilationStatus::NotApplicable;
615+
return result;
557616
}
558617

559-
// We have to get these now, they disappear later. Also, if these attributes
560-
// aren't set the contract of the applicability pipeline changed and that's
561-
// a problem.
562-
SmallVector<uint32_t> blockSizes;
563-
SmallVector<uint32_t> gridSizes;
618+
// Extract block and grid sizes
564619
for (auto &fnName : kernelFuncNames) {
565620
auto tunedFunc = applicabilityCopy->lookupSymbol<func::FuncOp>(fnName);
566621
if (!tunedFunc) {
567-
llvm::errs() << "Tuned copy somehow missing kernel function\n";
568-
return failure();
622+
result.status = CompilationStatus::CompilationFailed;
623+
compilationFailed.store(true, std::memory_order_relaxed);
624+
return result;
569625
}
570-
blockSizes.push_back(
626+
result.blockSizes.push_back(
571627
tunedFunc->getAttrOfType<IntegerAttr>("block_size").getInt());
572-
gridSizes.push_back(
628+
result.gridSizes.push_back(
573629
tunedFunc->getAttrOfType<IntegerAttr>("grid_size").getInt());
574630
}
575631

576-
OwningOpRef<ModuleOp> compileCopy = copyIR(source, perfConfigAttr);
577-
578-
// NOTE: Call to run() resets the cl opts
579-
if (failed(compilation.run(compileCopy.get()))) {
580-
llvm::errs() << "Backend pipeline failed for config: " << perfConfig
581-
<< "\n";
582-
return failure();
632+
// Compilation
633+
OwningOpRef<ModuleOp> compileCopy =
634+
copyIRThread(threadSource.get(), perfConfigAttr);
635+
if (failed(threadCompilation.run(compileCopy.get()))) {
636+
std::lock_guard<std::mutex> lock(outputMutex);
637+
llvm::errs() << "Backend pipeline failed for config: "
638+
<< result.perfConfig << "\n";
639+
result.status = CompilationStatus::CompilationFailed;
640+
compilationFailed.store(true, std::memory_order_relaxed);
641+
return result;
583642
}
584643

585-
// Extract binary and benchmark
586-
SmallVector<std::string> hipModules;
644+
// Extract binaries
587645
for (const auto &fnName : kernelFuncNames) {
588646
auto binary =
589647
compileCopy->lookupSymbol<gpu::BinaryOp>(fnName + "_module");
590648
if (!binary) {
591-
llvm::errs() << "could not find the GPU binary\n";
649+
result.status = CompilationStatus::CompilationFailed;
650+
compilationFailed.store(true, std::memory_order_relaxed);
651+
return result;
592652
}
593-
hipModules.push_back(cast<gpu::ObjectAttr>(binary.getObjects()[0])
594-
.getObject()
595-
.getValue()
596-
.str());
653+
result.hipModules.push_back(cast<gpu::ObjectAttr>(binary.getObjects()[0])
654+
.getObject()
655+
.getValue()
656+
.str());
597657
}
598658

659+
result.status = CompilationStatus::Success;
660+
return result;
661+
};
662+
663+
// Launch parallel compilation tasks with dynamic work stealing
664+
// Note: We use atomic counter instead of static partitioning because
665+
// compilation times vary dramatically between configs (NotApplicable is fast,
666+
// full compilation is slow). Dynamic work stealing provides better load
667+
// balancing by allowing fast threads to pick up more work.
668+
{
669+
std::atomic<size_t> nextIdx{0};
670+
671+
// Thread pool with work stealing pattern
672+
auto worker = [&]() {
673+
while (true) {
674+
// Check if any compilation has failed (relaxed: just an optimization
675+
// hint)
676+
if (compilationFailed.load(std::memory_order_relaxed))
677+
break;
678+
679+
size_t idx = nextIdx.fetch_add(1, std::memory_order_relaxed);
680+
if (idx >= configs.size())
681+
break;
682+
683+
compilationResults[idx] = compileConfig(idx);
684+
}
685+
};
686+
687+
std::vector<std::thread> threads;
688+
for (unsigned i = 0; i < numThreads; ++i) {
689+
threads.emplace_back(worker);
690+
}
691+
692+
for (auto &t : threads) {
693+
t.join();
694+
}
695+
}
696+
697+
// Check if any compilation failed and terminate early
698+
if (compilationFailed.load(std::memory_order_relaxed)) {
699+
llvm::errs()
700+
<< "Compilation failed for one or more configs. Terminating.\n";
701+
return failure();
702+
}
703+
704+
// Sequential benchmarking phase (must be sequential for accurate timing)
705+
// Note: Due to early exit on compilation failures, only NotApplicable and
706+
// Success statuses are possible here.
707+
for (const auto &result : compilationResults) {
708+
llvm::outs() << result.perfConfig << "\t";
709+
710+
if (result.status == CompilationStatus::NotApplicable) {
711+
llvm::outs() << "N/A\n";
712+
continue;
713+
}
714+
715+
// At this point, status must be Success (we exited early on any failures)
716+
assert(result.status == CompilationStatus::Success &&
717+
"Unexpected compilation status in benchmarking phase");
718+
599719
FailureOr<double> timing = benchmarkKernels(
600-
hipModules, kernelFuncNames, blockSizes, gridSizes, dataType,
601-
hostBuffers, gpuBuffers, bufferLengths, benchmarkParams);
720+
result.hipModules, kernelFuncNames, result.blockSizes, result.gridSizes,
721+
dataType, hostBuffers, gpuBuffers, bufferLengths, benchmarkParams);
722+
602723
if (failed(timing)) {
603724
llvm::errs() << "Kernel execution failed\n";
604725
return failure();

0 commit comments

Comments
 (0)